transport.go raw

   1  // Copyright 2023 Google LLC
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package httptransport
  16  
  17  import (
  18  	"context"
  19  	"crypto/tls"
  20  	"net"
  21  	"net/http"
  22  	"os"
  23  	"time"
  24  
  25  	"cloud.google.com/go/auth"
  26  	"cloud.google.com/go/auth/credentials"
  27  	"cloud.google.com/go/auth/internal"
  28  	"cloud.google.com/go/auth/internal/transport"
  29  	"cloud.google.com/go/auth/internal/transport/cert"
  30  	"cloud.google.com/go/auth/internal/transport/headers"
  31  	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
  32  	"golang.org/x/net/http2"
  33  )
  34  
  35  const (
  36  	quotaProjectHeaderKey = "X-goog-user-project"
  37  )
  38  
  39  func newTransport(base http.RoundTripper, opts *Options) (http.RoundTripper, error) {
  40  	var headers = opts.Headers
  41  	ht := &headerTransport{
  42  		base:    base,
  43  		headers: headers,
  44  	}
  45  	var trans http.RoundTripper = ht
  46  	trans = addOpenTelemetryTransport(trans, opts)
  47  	switch {
  48  	case opts.DisableAuthentication:
  49  		// Do nothing.
  50  	case opts.APIKey != "":
  51  		qp := internal.GetQuotaProject(nil, opts.Headers.Get(quotaProjectHeaderKey))
  52  		if qp != "" {
  53  			if headers == nil {
  54  				headers = make(map[string][]string, 1)
  55  			}
  56  			headers.Set(quotaProjectHeaderKey, qp)
  57  		}
  58  		trans = &apiKeyTransport{
  59  			Transport: trans,
  60  			Key:       opts.APIKey,
  61  		}
  62  	default:
  63  		var creds *auth.Credentials
  64  		if opts.Credentials != nil {
  65  			creds = opts.Credentials
  66  		} else {
  67  			var err error
  68  			creds, err = credentials.DetectDefault(opts.resolveDetectOptions())
  69  			if err != nil {
  70  				return nil, err
  71  			}
  72  		}
  73  		qp, err := creds.QuotaProjectID(context.Background())
  74  		if err != nil {
  75  			return nil, err
  76  		}
  77  		if qp != "" {
  78  			if headers == nil {
  79  				headers = make(map[string][]string, 1)
  80  			}
  81  			// Don't overwrite user specified quota
  82  			if v := headers.Get(quotaProjectHeaderKey); v == "" {
  83  				headers.Set(quotaProjectHeaderKey, qp)
  84  			}
  85  		}
  86  		var skipUD bool
  87  		if iOpts := opts.InternalOptions; iOpts != nil {
  88  			skipUD = iOpts.SkipUniverseDomainValidation
  89  		}
  90  		creds.TokenProvider = auth.NewCachedTokenProvider(creds.TokenProvider, nil)
  91  		trans = &authTransport{
  92  			base:                         trans,
  93  			creds:                        creds,
  94  			clientUniverseDomain:         opts.UniverseDomain,
  95  			skipUniverseDomainValidation: skipUD,
  96  		}
  97  	}
  98  	return trans, nil
  99  }
 100  
 101  // defaultBaseTransport returns the base HTTP transport.
 102  // On App Engine, this is urlfetch.Transport.
 103  // Otherwise, use a default transport, taking most defaults from
 104  // http.DefaultTransport.
 105  // If TLSCertificate is available, set TLSClientConfig as well.
 106  func defaultBaseTransport(clientCertSource cert.Provider, dialTLSContext func(context.Context, string, string) (net.Conn, error)) http.RoundTripper {
 107  	defaultTransport, ok := http.DefaultTransport.(*http.Transport)
 108  	if !ok {
 109  		defaultTransport = transport.BaseTransport()
 110  	}
 111  	trans := defaultTransport.Clone()
 112  	trans.MaxIdleConnsPerHost = 100
 113  
 114  	if clientCertSource != nil {
 115  		trans.TLSClientConfig = &tls.Config{
 116  			GetClientCertificate: clientCertSource,
 117  		}
 118  	}
 119  	if dialTLSContext != nil {
 120  		// If DialTLSContext is set, TLSClientConfig wil be ignored
 121  		trans.DialTLSContext = dialTLSContext
 122  	}
 123  
 124  	// Configures the ReadIdleTimeout HTTP/2 option for the
 125  	// transport. This allows broken idle connections to be pruned more quickly,
 126  	// preventing the client from attempting to re-use connections that will no
 127  	// longer work.
 128  	http2Trans, err := http2.ConfigureTransports(trans)
 129  	if err == nil {
 130  		http2Trans.ReadIdleTimeout = time.Second * 31
 131  	}
 132  
 133  	return trans
 134  }
 135  
 136  type apiKeyTransport struct {
 137  	// Key is the API Key to set on requests.
 138  	Key string
 139  	// Transport is the underlying HTTP transport.
 140  	// If nil, http.DefaultTransport is used.
 141  	Transport http.RoundTripper
 142  }
 143  
 144  func (t *apiKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 145  	newReq := *req
 146  	args := newReq.URL.Query()
 147  	args.Set("key", t.Key)
 148  	newReq.URL.RawQuery = args.Encode()
 149  	return t.Transport.RoundTrip(&newReq)
 150  }
 151  
 152  type headerTransport struct {
 153  	headers http.Header
 154  	base    http.RoundTripper
 155  }
 156  
 157  func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 158  	rt := t.base
 159  	newReq := *req
 160  	newReq.Header = make(http.Header)
 161  	for k, vv := range req.Header {
 162  		newReq.Header[k] = vv
 163  	}
 164  
 165  	for k, v := range t.headers {
 166  		newReq.Header[k] = v
 167  	}
 168  
 169  	return rt.RoundTrip(&newReq)
 170  }
 171  
 172  func addOpenTelemetryTransport(trans http.RoundTripper, opts *Options) http.RoundTripper {
 173  	if opts.DisableTelemetry {
 174  		return trans
 175  	}
 176  	return otelhttp.NewTransport(trans)
 177  }
 178  
 179  type authTransport struct {
 180  	creds                        *auth.Credentials
 181  	base                         http.RoundTripper
 182  	clientUniverseDomain         string
 183  	skipUniverseDomainValidation bool
 184  }
 185  
 186  // getClientUniverseDomain returns the default service domain for a given Cloud
 187  // universe, with the following precedence:
 188  //
 189  // 1. A non-empty option.WithUniverseDomain or similar client option.
 190  // 2. A non-empty environment variable GOOGLE_CLOUD_UNIVERSE_DOMAIN.
 191  // 3. The default value "googleapis.com".
 192  //
 193  // This is the universe domain configured for the client, which will be compared
 194  // to the universe domain that is separately configured for the credentials.
 195  func (t *authTransport) getClientUniverseDomain() string {
 196  	if t.clientUniverseDomain != "" {
 197  		return t.clientUniverseDomain
 198  	}
 199  	if envUD := os.Getenv(internal.UniverseDomainEnvVar); envUD != "" {
 200  		return envUD
 201  	}
 202  	return internal.DefaultUniverseDomain
 203  }
 204  
 205  // RoundTrip authorizes and authenticates the request with an
 206  // access token from Transport's Source. Per the RoundTripper contract we must
 207  // not modify the initial request, so we clone it, and we must close the body
 208  // on any errors that happens during our token logic.
 209  func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 210  	reqBodyClosed := false
 211  	if req.Body != nil {
 212  		defer func() {
 213  			if !reqBodyClosed {
 214  				req.Body.Close()
 215  			}
 216  		}()
 217  	}
 218  	token, err := t.creds.Token(req.Context())
 219  	if err != nil {
 220  		return nil, err
 221  	}
 222  	if !t.skipUniverseDomainValidation && token.MetadataString("auth.google.tokenSource") != "compute-metadata" {
 223  		credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context())
 224  		if err != nil {
 225  			return nil, err
 226  		}
 227  		if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
 228  			return nil, err
 229  		}
 230  	}
 231  	req2 := req.Clone(req.Context())
 232  	headers.SetAuthHeader(token, req2)
 233  	reqBodyClosed = true
 234  	return t.base.RoundTrip(req2)
 235  }
 236