cba.go raw
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package transport
16
17 import (
18 "context"
19 "crypto/tls"
20 "crypto/x509"
21 "errors"
22 "log"
23 "log/slog"
24 "net"
25 "net/http"
26 "net/url"
27 "os"
28 "strconv"
29 "strings"
30
31 "cloud.google.com/go/auth/internal"
32 "cloud.google.com/go/auth/internal/transport/cert"
33 "github.com/google/s2a-go"
34 "google.golang.org/grpc/credentials"
35 )
36
// Accepted (lower-cased) values of the GOOGLE_API_USE_MTLS_ENDPOINT and
// legacy GOOGLE_API_USE_MTLS environment variables.
const (
	mTLSModeAlways = "always"
	mTLSModeNever  = "never"
	mTLSModeAuto   = "auto"
)

// Environment variables consulted when choosing transport security.
const (
	// Experimental: if true, the code will try MTLS with S2A as the default for
	// transport security. Default value is false.
	googleAPIUseS2AEnv = "EXPERIMENTAL_GOOGLE_API_USE_S2A"
	// Enables or disables use of a client certificate source.
	googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
	// Selects the mTLS endpoint mode ("always", "never", "auto").
	googleAPIUseMTLS = "GOOGLE_API_USE_MTLS_ENDPOINT"
	// Deprecated spelling of googleAPIUseMTLS, read only when it is unset.
	googleAPIUseMTLSOld = "GOOGLE_API_USE_MTLS"
)

// universeDomainPlaceholder is the token in endpoint templates that is
// substituted with the configured universe domain.
const universeDomainPlaceholder = "UNIVERSE_DOMAIN"

// Files provisioned for mTLS to the metadata server (MDS): the trusted root
// certificate(s), and the client certificate chain concatenated with its key.
const (
	mtlsMDSRoot = "/run/google-mds-mtls/root.crt"
	mtlsMDSKey  = "/run/google-mds-mtls/client.key"
)
53
// Type represents the type of transport used.
type Type int

// Known transport types. The values are spelled out explicitly so they stay
// stable; the zero value of Type is TransportTypeUnknown.
const (
	// TransportTypeUnknown represents an unknown transport type and is the default option.
	TransportTypeUnknown Type = 0
	// TransportTypeMTLSS2A represents the mTLS transport type using S2A.
	TransportTypeMTLSS2A Type = 1
)
63
64 // Options is a struct that is duplicated information from the individual
65 // transport packages in order to avoid cyclic deps. It correlates 1:1 with
66 // fields on httptransport.Options and grpctransport.Options.
67 type Options struct {
68 Endpoint string
69 DefaultEndpointTemplate string
70 DefaultMTLSEndpoint string
71 ClientCertProvider cert.Provider
72 Client *http.Client
73 UniverseDomain string
74 EnableDirectPath bool
75 EnableDirectPathXds bool
76 Logger *slog.Logger
77 }
78
79 // getUniverseDomain returns the default service domain for a given Cloud
80 // universe.
81 func (o *Options) getUniverseDomain() string {
82 if o.UniverseDomain == "" {
83 return internal.DefaultUniverseDomain
84 }
85 return o.UniverseDomain
86 }
87
88 // isUniverseDomainGDU returns true if the universe domain is the default Google
89 // universe.
90 func (o *Options) isUniverseDomainGDU() bool {
91 return o.getUniverseDomain() == internal.DefaultUniverseDomain
92 }
93
94 // defaultEndpoint returns the DefaultEndpointTemplate merged with the
95 // universe domain if the DefaultEndpointTemplate is set, otherwise returns an
96 // empty string.
97 func (o *Options) defaultEndpoint() string {
98 if o.DefaultEndpointTemplate == "" {
99 return ""
100 }
101 return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1)
102 }
103
104 // defaultMTLSEndpoint returns the DefaultMTLSEndpointTemplate merged with the
105 // universe domain if the DefaultMTLSEndpointTemplate is set, otherwise returns an
106 // empty string.
107 func (o *Options) defaultMTLSEndpoint() string {
108 if o.DefaultMTLSEndpoint == "" {
109 return ""
110 }
111 return strings.Replace(o.DefaultMTLSEndpoint, universeDomainPlaceholder, o.getUniverseDomain(), 1)
112 }
113
114 // mergedEndpoint merges a user-provided Endpoint of format host[:port] with the
115 // default endpoint.
116 func (o *Options) mergedEndpoint() (string, error) {
117 defaultEndpoint := o.defaultEndpoint()
118 u, err := url.Parse(fixScheme(defaultEndpoint))
119 if err != nil {
120 return "", err
121 }
122 return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil
123 }
124
// fixScheme returns baseURL unchanged when it already contains a scheme
// separator ("://"), and otherwise prefixes it with "https://".
func fixScheme(baseURL string) string {
	if strings.Contains(baseURL, "://") {
		return baseURL
	}
	return "https://" + baseURL
}
131
132 // GRPCTransportCredentials embeds interface TransportCredentials with additional data.
133 type GRPCTransportCredentials struct {
134 credentials.TransportCredentials
135 Endpoint string
136 TransportType Type
137 }
138
139 // GetGRPCTransportCredsAndEndpoint returns an instance of
140 // [google.golang.org/grpc/credentials.TransportCredentials], and the
141 // corresponding endpoint and transport type to use for GRPC client.
142 func GetGRPCTransportCredsAndEndpoint(opts *Options) (*GRPCTransportCredentials, error) {
143 config, err := getTransportConfig(opts)
144 if err != nil {
145 return nil, err
146 }
147
148 defaultTransportCreds := credentials.NewTLS(&tls.Config{
149 GetClientCertificate: config.clientCertSource,
150 })
151
152 var s2aAddr string
153 var transportCredsForS2A credentials.TransportCredentials
154
155 if config.mtlsS2AAddress != "" {
156 s2aAddr = config.mtlsS2AAddress
157 transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
158 if err != nil {
159 log.Printf("Loading MTLS MDS credentials failed: %v", err)
160 if config.s2aAddress != "" {
161 s2aAddr = config.s2aAddress
162 } else {
163 return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint, TransportTypeUnknown}, nil
164 }
165 }
166 } else if config.s2aAddress != "" {
167 s2aAddr = config.s2aAddress
168 } else {
169 return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint, TransportTypeUnknown}, nil
170 }
171
172 s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
173 S2AAddress: s2aAddr,
174 TransportCreds: transportCredsForS2A,
175 })
176 if err != nil {
177 // Use default if we cannot initialize S2A client transport credentials.
178 return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint, TransportTypeUnknown}, nil
179 }
180 return &GRPCTransportCredentials{s2aTransportCreds, config.s2aMTLSEndpoint, TransportTypeMTLSS2A}, nil
181 }
182
183 // GetHTTPTransportConfig returns a client certificate source and a function for
184 // dialing MTLS with S2A.
185 func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, string, string) (net.Conn, error), error) {
186 config, err := getTransportConfig(opts)
187 if err != nil {
188 return nil, nil, err
189 }
190
191 var s2aAddr string
192 var transportCredsForS2A credentials.TransportCredentials
193
194 if config.mtlsS2AAddress != "" {
195 s2aAddr = config.mtlsS2AAddress
196 transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
197 if err != nil {
198 log.Printf("Loading MTLS MDS credentials failed: %v", err)
199 if config.s2aAddress != "" {
200 s2aAddr = config.s2aAddress
201 } else {
202 return config.clientCertSource, nil, nil
203 }
204 }
205 } else if config.s2aAddress != "" {
206 s2aAddr = config.s2aAddress
207 } else {
208 return config.clientCertSource, nil, nil
209 }
210
211 dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
212 S2AAddress: s2aAddr,
213 TransportCreds: transportCredsForS2A,
214 })
215 return nil, dialTLSContextFunc, nil
216 }
217
218 func loadMTLSMDSTransportCreds(mtlsMDSRootFile, mtlsMDSKeyFile string) (credentials.TransportCredentials, error) {
219 rootPEM, err := os.ReadFile(mtlsMDSRootFile)
220 if err != nil {
221 return nil, err
222 }
223 caCertPool := x509.NewCertPool()
224 ok := caCertPool.AppendCertsFromPEM(rootPEM)
225 if !ok {
226 return nil, errors.New("failed to load MTLS MDS root certificate")
227 }
228 // The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
229 // followed by a PEM-encoded private key. For this reason, the concatenation is passed in to the
230 // tls.X509KeyPair function as both the certificate chain and private key arguments.
231 cert, err := tls.LoadX509KeyPair(mtlsMDSKeyFile, mtlsMDSKeyFile)
232 if err != nil {
233 return nil, err
234 }
235 tlsConfig := tls.Config{
236 RootCAs: caCertPool,
237 Certificates: []tls.Certificate{cert},
238 MinVersion: tls.VersionTLS13,
239 }
240 return credentials.NewTLS(&tlsConfig), nil
241 }
242
243 func getTransportConfig(opts *Options) (*transportConfig, error) {
244 clientCertSource, err := GetClientCertificateProvider(opts)
245 if err != nil {
246 return nil, err
247 }
248 endpoint, err := getEndpoint(opts, clientCertSource)
249 if err != nil {
250 return nil, err
251 }
252 defaultTransportConfig := transportConfig{
253 clientCertSource: clientCertSource,
254 endpoint: endpoint,
255 }
256
257 if !shouldUseS2A(clientCertSource, opts) {
258 return &defaultTransportConfig, nil
259 }
260
261 s2aAddress := GetS2AAddress(opts.Logger)
262 mtlsS2AAddress := GetMTLSS2AAddress(opts.Logger)
263 if s2aAddress == "" && mtlsS2AAddress == "" {
264 return &defaultTransportConfig, nil
265 }
266 return &transportConfig{
267 clientCertSource: clientCertSource,
268 endpoint: endpoint,
269 s2aAddress: s2aAddress,
270 mtlsS2AAddress: mtlsS2AAddress,
271 s2aMTLSEndpoint: opts.defaultMTLSEndpoint(),
272 }, nil
273 }
274
275 // GetClientCertificateProvider returns a default client certificate source, if
276 // not provided by the user.
277 //
278 // A nil default source can be returned if the source does not exist. Any exceptions
279 // encountered while initializing the default source will be reported as client
280 // error (ex. corrupt metadata file).
281 func GetClientCertificateProvider(opts *Options) (cert.Provider, error) {
282 if !isClientCertificateEnabled(opts) {
283 return nil, nil
284 } else if opts.ClientCertProvider != nil {
285 return opts.ClientCertProvider, nil
286 }
287 return cert.DefaultProvider()
288
289 }
290
291 // isClientCertificateEnabled returns true by default for all GDU universe domain, unless explicitly overridden by env var
292 func isClientCertificateEnabled(opts *Options) bool {
293 if value, ok := os.LookupEnv(googleAPIUseCertSource); ok {
294 // error as false is OK
295 b, _ := strconv.ParseBool(value)
296 return b
297 }
298 return opts.isUniverseDomainGDU()
299 }
300
301 type transportConfig struct {
302 // The client certificate source.
303 clientCertSource cert.Provider
304 // The corresponding endpoint to use based on client certificate source.
305 endpoint string
306 // The plaintext S2A address if it can be used, otherwise an empty string.
307 s2aAddress string
308 // The MTLS S2A address if it can be used, otherwise an empty string.
309 mtlsS2AAddress string
310 // The MTLS endpoint to use with S2A.
311 s2aMTLSEndpoint string
312 }
313
314 // getEndpoint returns the endpoint for the service, taking into account the
315 // user-provided endpoint override "settings.Endpoint".
316 //
317 // If no endpoint override is specified, we will either return the default
318 // endpoint or the default mTLS endpoint if a client certificate is available.
319 //
320 // You can override the default endpoint choice (mTLS vs. regular) by setting
321 // the GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
322 //
323 // If the endpoint override is an address (host:port) rather than full base
324 // URL (ex. https://...), then the user-provided address will be merged into
325 // the default endpoint. For example, WithEndpoint("myhost:8000") and
326 // DefaultEndpointTemplate("https://UNIVERSE_DOMAIN/bar/baz") will return
327 // "https://myhost:8080/bar/baz". Note that this does not apply to the mTLS
328 // endpoint.
329 func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) {
330 if opts.Endpoint == "" {
331 mtlsMode := getMTLSMode()
332 if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
333 return opts.defaultMTLSEndpoint(), nil
334 }
335 return opts.defaultEndpoint(), nil
336 }
337 if strings.Contains(opts.Endpoint, "://") {
338 // User passed in a full URL path, use it verbatim.
339 return opts.Endpoint, nil
340 }
341 if opts.defaultEndpoint() == "" {
342 // If DefaultEndpointTemplate is not configured,
343 // use the user provided endpoint verbatim. This allows a naked
344 // "host[:port]" URL to be used with GRPC Direct Path.
345 return opts.Endpoint, nil
346 }
347
348 // Assume user-provided endpoint is host[:port], merge it with the default endpoint.
349 return opts.mergedEndpoint()
350 }
351
352 func getMTLSMode() string {
353 mode := os.Getenv(googleAPIUseMTLS)
354 if mode == "" {
355 mode = os.Getenv(googleAPIUseMTLSOld) // Deprecated.
356 }
357 if mode == "" {
358 return mTLSModeAuto
359 }
360 return strings.ToLower(mode)
361 }
362