transport.go raw
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package httptransport
16
17 import (
18 "context"
19 "crypto/tls"
20 "net"
21 "net/http"
22 "os"
23 "time"
24
25 "cloud.google.com/go/auth"
26 "cloud.google.com/go/auth/credentials"
27 "cloud.google.com/go/auth/internal"
28 "cloud.google.com/go/auth/internal/transport"
29 "cloud.google.com/go/auth/internal/transport/cert"
30 "cloud.google.com/go/auth/internal/transport/headers"
31 "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
32 "golang.org/x/net/http2"
33 )
34
// quotaProjectHeaderKey is the HTTP header used to send the quota project
// associated with a request.
const quotaProjectHeaderKey = "X-goog-user-project"
38
39 func newTransport(base http.RoundTripper, opts *Options) (http.RoundTripper, error) {
40 var headers = opts.Headers
41 ht := &headerTransport{
42 base: base,
43 headers: headers,
44 }
45 var trans http.RoundTripper = ht
46 trans = addOpenTelemetryTransport(trans, opts)
47 switch {
48 case opts.DisableAuthentication:
49 // Do nothing.
50 case opts.APIKey != "":
51 qp := internal.GetQuotaProject(nil, opts.Headers.Get(quotaProjectHeaderKey))
52 if qp != "" {
53 if headers == nil {
54 headers = make(map[string][]string, 1)
55 }
56 headers.Set(quotaProjectHeaderKey, qp)
57 }
58 trans = &apiKeyTransport{
59 Transport: trans,
60 Key: opts.APIKey,
61 }
62 default:
63 var creds *auth.Credentials
64 if opts.Credentials != nil {
65 creds = opts.Credentials
66 } else {
67 var err error
68 creds, err = credentials.DetectDefault(opts.resolveDetectOptions())
69 if err != nil {
70 return nil, err
71 }
72 }
73 qp, err := creds.QuotaProjectID(context.Background())
74 if err != nil {
75 return nil, err
76 }
77 if qp != "" {
78 if headers == nil {
79 headers = make(map[string][]string, 1)
80 }
81 // Don't overwrite user specified quota
82 if v := headers.Get(quotaProjectHeaderKey); v == "" {
83 headers.Set(quotaProjectHeaderKey, qp)
84 }
85 }
86 var skipUD bool
87 if iOpts := opts.InternalOptions; iOpts != nil {
88 skipUD = iOpts.SkipUniverseDomainValidation
89 }
90 creds.TokenProvider = auth.NewCachedTokenProvider(creds.TokenProvider, nil)
91 trans = &authTransport{
92 base: trans,
93 creds: creds,
94 clientUniverseDomain: opts.UniverseDomain,
95 skipUniverseDomainValidation: skipUD,
96 }
97 }
98 return trans, nil
99 }
100
101 // defaultBaseTransport returns the base HTTP transport.
102 // On App Engine, this is urlfetch.Transport.
103 // Otherwise, use a default transport, taking most defaults from
104 // http.DefaultTransport.
105 // If TLSCertificate is available, set TLSClientConfig as well.
106 func defaultBaseTransport(clientCertSource cert.Provider, dialTLSContext func(context.Context, string, string) (net.Conn, error)) http.RoundTripper {
107 defaultTransport, ok := http.DefaultTransport.(*http.Transport)
108 if !ok {
109 defaultTransport = transport.BaseTransport()
110 }
111 trans := defaultTransport.Clone()
112 trans.MaxIdleConnsPerHost = 100
113
114 if clientCertSource != nil {
115 trans.TLSClientConfig = &tls.Config{
116 GetClientCertificate: clientCertSource,
117 }
118 }
119 if dialTLSContext != nil {
120 // If DialTLSContext is set, TLSClientConfig wil be ignored
121 trans.DialTLSContext = dialTLSContext
122 }
123
124 // Configures the ReadIdleTimeout HTTP/2 option for the
125 // transport. This allows broken idle connections to be pruned more quickly,
126 // preventing the client from attempting to re-use connections that will no
127 // longer work.
128 http2Trans, err := http2.ConfigureTransports(trans)
129 if err == nil {
130 http2Trans.ReadIdleTimeout = time.Second * 31
131 }
132
133 return trans
134 }
135
// apiKeyTransport is an http.RoundTripper that authenticates requests by
// adding an API key as the "key" URL query parameter.
type apiKeyTransport struct {
	// Key is the API Key to set on requests.
	Key string
	// Transport is the underlying HTTP transport.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper
}

// RoundTrip sends a copy of req whose URL query has its "key" parameter set to
// t.Key, replacing any existing "key" value. Per the http.RoundTripper
// contract, req itself, including req.URL, is not modified.
func (t *apiKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// A shallow copy of *req still shares req.URL, so copy the URL too before
	// rewriting its query. Otherwise the caller's request would be mutated.
	newReq := *req
	u := *req.URL
	newReq.URL = &u

	args := u.Query()
	args.Set("key", t.Key)
	u.RawQuery = args.Encode()

	rt := t.Transport
	if rt == nil {
		// Honor the documented fallback instead of panicking on a nil field.
		rt = http.DefaultTransport
	}
	return rt.RoundTrip(&newReq)
}
151
// headerTransport is an http.RoundTripper that adds a fixed set of headers to
// every outgoing request before delegating to base.
type headerTransport struct {
	headers http.Header
	base    http.RoundTripper
}

// RoundTrip forwards a shallow copy of req to t.base. The copy gets a brand
// new header map holding req's headers followed by t.headers; when both use
// the same map key, the value from t.headers wins. req.Header is not modified.
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	merged := make(http.Header, len(req.Header)+len(t.headers))
	// Order matters: later sources overwrite earlier ones.
	for _, src := range []http.Header{req.Header, t.headers} {
		for name, values := range src {
			merged[name] = values
		}
	}

	outgoing := *req
	outgoing.Header = merged
	return t.base.RoundTrip(&outgoing)
}
171
172 func addOpenTelemetryTransport(trans http.RoundTripper, opts *Options) http.RoundTripper {
173 if opts.DisableTelemetry {
174 return trans
175 }
176 return otelhttp.NewTransport(trans)
177 }
178
// authTransport is an http.RoundTripper that attaches an access token from
// creds to each request and, unless disabled, checks that the credentials'
// universe domain matches the client's before sending.
type authTransport struct {
	// creds supplies the access token and the credentials' universe domain.
	creds *auth.Credentials
	// base is the next RoundTripper; it receives the authorized request.
	base http.RoundTripper
	// clientUniverseDomain is the universe domain configured for the client
	// (newTransport sets it from Options.UniverseDomain). If empty,
	// getClientUniverseDomain falls back to the environment or the default.
	clientUniverseDomain string
	// skipUniverseDomainValidation disables the universe domain check in
	// RoundTrip.
	skipUniverseDomainValidation bool
}
185
186 // getClientUniverseDomain returns the default service domain for a given Cloud
187 // universe, with the following precedence:
188 //
189 // 1. A non-empty option.WithUniverseDomain or similar client option.
190 // 2. A non-empty environment variable GOOGLE_CLOUD_UNIVERSE_DOMAIN.
191 // 3. The default value "googleapis.com".
192 //
193 // This is the universe domain configured for the client, which will be compared
194 // to the universe domain that is separately configured for the credentials.
195 func (t *authTransport) getClientUniverseDomain() string {
196 if t.clientUniverseDomain != "" {
197 return t.clientUniverseDomain
198 }
199 if envUD := os.Getenv(internal.UniverseDomainEnvVar); envUD != "" {
200 return envUD
201 }
202 return internal.DefaultUniverseDomain
203 }
204
205 // RoundTrip authorizes and authenticates the request with an
206 // access token from Transport's Source. Per the RoundTripper contract we must
207 // not modify the initial request, so we clone it, and we must close the body
208 // on any errors that happens during our token logic.
209 func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
210 reqBodyClosed := false
211 if req.Body != nil {
212 defer func() {
213 if !reqBodyClosed {
214 req.Body.Close()
215 }
216 }()
217 }
218 token, err := t.creds.Token(req.Context())
219 if err != nil {
220 return nil, err
221 }
222 if !t.skipUniverseDomainValidation && token.MetadataString("auth.google.tokenSource") != "compute-metadata" {
223 credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context())
224 if err != nil {
225 return nil, err
226 }
227 if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
228 return nil, err
229 }
230 }
231 req2 := req.Clone(req.Context())
232 headers.SetAuthHeader(token, req2)
233 reqBodyClosed = true
234 return t.base.RoundTrip(req2)
235 }
236