cba.go raw

   1  // Copyright 2023 Google LLC
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package transport
  16  
  17  import (
  18  	"context"
  19  	"crypto/tls"
  20  	"crypto/x509"
  21  	"errors"
  22  	"log"
  23  	"log/slog"
  24  	"net"
  25  	"net/http"
  26  	"net/url"
  27  	"os"
  28  	"strconv"
  29  	"strings"
  30  
  31  	"cloud.google.com/go/auth/internal"
  32  	"cloud.google.com/go/auth/internal/transport/cert"
  33  	"github.com/google/s2a-go"
  34  	"google.golang.org/grpc/credentials"
  35  )
  36  
  37  const (
  38  	mTLSModeAlways = "always"
  39  	mTLSModeNever  = "never"
  40  	mTLSModeAuto   = "auto"
  41  
  42  	// Experimental: if true, the code will try MTLS with S2A as the default for transport security. Default value is false.
  43  	googleAPIUseS2AEnv     = "EXPERIMENTAL_GOOGLE_API_USE_S2A"
  44  	googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
  45  	googleAPIUseMTLS       = "GOOGLE_API_USE_MTLS_ENDPOINT"
  46  	googleAPIUseMTLSOld    = "GOOGLE_API_USE_MTLS"
  47  
  48  	universeDomainPlaceholder = "UNIVERSE_DOMAIN"
  49  
  50  	mtlsMDSRoot = "/run/google-mds-mtls/root.crt"
  51  	mtlsMDSKey  = "/run/google-mds-mtls/client.key"
  52  )
  53  
  54  // Type represents the type of transport used.
  55  type Type int
  56  
  57  const (
  58  	// TransportTypeUnknown represents an unknown transport type and is the default option.
  59  	TransportTypeUnknown Type = iota
  60  	// TransportTypeMTLSS2A represents the mTLS transport type using S2A.
  61  	TransportTypeMTLSS2A
  62  )
  63  
  64  // Options is a struct that is duplicated information from the individual
  65  // transport packages in order to avoid cyclic deps. It correlates 1:1 with
  66  // fields on httptransport.Options and grpctransport.Options.
  67  type Options struct {
  68  	Endpoint                string
  69  	DefaultEndpointTemplate string
  70  	DefaultMTLSEndpoint     string
  71  	ClientCertProvider      cert.Provider
  72  	Client                  *http.Client
  73  	UniverseDomain          string
  74  	EnableDirectPath        bool
  75  	EnableDirectPathXds     bool
  76  	Logger                  *slog.Logger
  77  }
  78  
  79  // getUniverseDomain returns the default service domain for a given Cloud
  80  // universe.
  81  func (o *Options) getUniverseDomain() string {
  82  	if o.UniverseDomain == "" {
  83  		return internal.DefaultUniverseDomain
  84  	}
  85  	return o.UniverseDomain
  86  }
  87  
  88  // isUniverseDomainGDU returns true if the universe domain is the default Google
  89  // universe.
  90  func (o *Options) isUniverseDomainGDU() bool {
  91  	return o.getUniverseDomain() == internal.DefaultUniverseDomain
  92  }
  93  
  94  // defaultEndpoint returns the DefaultEndpointTemplate merged with the
  95  // universe domain if the DefaultEndpointTemplate is set, otherwise returns an
  96  // empty string.
  97  func (o *Options) defaultEndpoint() string {
  98  	if o.DefaultEndpointTemplate == "" {
  99  		return ""
 100  	}
 101  	return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1)
 102  }
 103  
 104  // defaultMTLSEndpoint returns the DefaultMTLSEndpointTemplate merged with the
 105  // universe domain if the DefaultMTLSEndpointTemplate is set, otherwise returns an
 106  // empty string.
 107  func (o *Options) defaultMTLSEndpoint() string {
 108  	if o.DefaultMTLSEndpoint == "" {
 109  		return ""
 110  	}
 111  	return strings.Replace(o.DefaultMTLSEndpoint, universeDomainPlaceholder, o.getUniverseDomain(), 1)
 112  }
 113  
 114  // mergedEndpoint merges a user-provided Endpoint of format host[:port] with the
 115  // default endpoint.
 116  func (o *Options) mergedEndpoint() (string, error) {
 117  	defaultEndpoint := o.defaultEndpoint()
 118  	u, err := url.Parse(fixScheme(defaultEndpoint))
 119  	if err != nil {
 120  		return "", err
 121  	}
 122  	return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil
 123  }
 124  
 125  func fixScheme(baseURL string) string {
 126  	if !strings.Contains(baseURL, "://") {
 127  		baseURL = "https://" + baseURL
 128  	}
 129  	return baseURL
 130  }
 131  
 132  // GRPCTransportCredentials embeds interface TransportCredentials with additional data.
 133  type GRPCTransportCredentials struct {
 134  	credentials.TransportCredentials
 135  	Endpoint      string
 136  	TransportType Type
 137  }
 138  
 139  // GetGRPCTransportCredsAndEndpoint returns an instance of
 140  // [google.golang.org/grpc/credentials.TransportCredentials], and the
 141  // corresponding endpoint and transport type to use for GRPC client.
 142  func GetGRPCTransportCredsAndEndpoint(opts *Options) (*GRPCTransportCredentials, error) {
 143  	config, err := getTransportConfig(opts)
 144  	if err != nil {
 145  		return nil, err
 146  	}
 147  
 148  	defaultTransportCreds := credentials.NewTLS(&tls.Config{
 149  		GetClientCertificate: config.clientCertSource,
 150  	})
 151  
 152  	var s2aAddr string
 153  	var transportCredsForS2A credentials.TransportCredentials
 154  
 155  	if config.mtlsS2AAddress != "" {
 156  		s2aAddr = config.mtlsS2AAddress
 157  		transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
 158  		if err != nil {
 159  			log.Printf("Loading MTLS MDS credentials failed: %v", err)
 160  			if config.s2aAddress != "" {
 161  				s2aAddr = config.s2aAddress
 162  			} else {
 163  				return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint, TransportTypeUnknown}, nil
 164  			}
 165  		}
 166  	} else if config.s2aAddress != "" {
 167  		s2aAddr = config.s2aAddress
 168  	} else {
 169  		return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint, TransportTypeUnknown}, nil
 170  	}
 171  
 172  	s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
 173  		S2AAddress:     s2aAddr,
 174  		TransportCreds: transportCredsForS2A,
 175  	})
 176  	if err != nil {
 177  		// Use default if we cannot initialize S2A client transport credentials.
 178  		return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint, TransportTypeUnknown}, nil
 179  	}
 180  	return &GRPCTransportCredentials{s2aTransportCreds, config.s2aMTLSEndpoint, TransportTypeMTLSS2A}, nil
 181  }
 182  
 183  // GetHTTPTransportConfig returns a client certificate source and a function for
 184  // dialing MTLS with S2A.
 185  func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, string, string) (net.Conn, error), error) {
 186  	config, err := getTransportConfig(opts)
 187  	if err != nil {
 188  		return nil, nil, err
 189  	}
 190  
 191  	var s2aAddr string
 192  	var transportCredsForS2A credentials.TransportCredentials
 193  
 194  	if config.mtlsS2AAddress != "" {
 195  		s2aAddr = config.mtlsS2AAddress
 196  		transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
 197  		if err != nil {
 198  			log.Printf("Loading MTLS MDS credentials failed: %v", err)
 199  			if config.s2aAddress != "" {
 200  				s2aAddr = config.s2aAddress
 201  			} else {
 202  				return config.clientCertSource, nil, nil
 203  			}
 204  		}
 205  	} else if config.s2aAddress != "" {
 206  		s2aAddr = config.s2aAddress
 207  	} else {
 208  		return config.clientCertSource, nil, nil
 209  	}
 210  
 211  	dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
 212  		S2AAddress:     s2aAddr,
 213  		TransportCreds: transportCredsForS2A,
 214  	})
 215  	return nil, dialTLSContextFunc, nil
 216  }
 217  
 218  func loadMTLSMDSTransportCreds(mtlsMDSRootFile, mtlsMDSKeyFile string) (credentials.TransportCredentials, error) {
 219  	rootPEM, err := os.ReadFile(mtlsMDSRootFile)
 220  	if err != nil {
 221  		return nil, err
 222  	}
 223  	caCertPool := x509.NewCertPool()
 224  	ok := caCertPool.AppendCertsFromPEM(rootPEM)
 225  	if !ok {
 226  		return nil, errors.New("failed to load MTLS MDS root certificate")
 227  	}
 228  	// The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
 229  	// followed by a PEM-encoded private key. For this reason, the concatenation is passed in to the
 230  	// tls.X509KeyPair function as both the certificate chain and private key arguments.
 231  	cert, err := tls.LoadX509KeyPair(mtlsMDSKeyFile, mtlsMDSKeyFile)
 232  	if err != nil {
 233  		return nil, err
 234  	}
 235  	tlsConfig := tls.Config{
 236  		RootCAs:      caCertPool,
 237  		Certificates: []tls.Certificate{cert},
 238  		MinVersion:   tls.VersionTLS13,
 239  	}
 240  	return credentials.NewTLS(&tlsConfig), nil
 241  }
 242  
 243  func getTransportConfig(opts *Options) (*transportConfig, error) {
 244  	clientCertSource, err := GetClientCertificateProvider(opts)
 245  	if err != nil {
 246  		return nil, err
 247  	}
 248  	endpoint, err := getEndpoint(opts, clientCertSource)
 249  	if err != nil {
 250  		return nil, err
 251  	}
 252  	defaultTransportConfig := transportConfig{
 253  		clientCertSource: clientCertSource,
 254  		endpoint:         endpoint,
 255  	}
 256  
 257  	if !shouldUseS2A(clientCertSource, opts) {
 258  		return &defaultTransportConfig, nil
 259  	}
 260  
 261  	s2aAddress := GetS2AAddress(opts.Logger)
 262  	mtlsS2AAddress := GetMTLSS2AAddress(opts.Logger)
 263  	if s2aAddress == "" && mtlsS2AAddress == "" {
 264  		return &defaultTransportConfig, nil
 265  	}
 266  	return &transportConfig{
 267  		clientCertSource: clientCertSource,
 268  		endpoint:         endpoint,
 269  		s2aAddress:       s2aAddress,
 270  		mtlsS2AAddress:   mtlsS2AAddress,
 271  		s2aMTLSEndpoint:  opts.defaultMTLSEndpoint(),
 272  	}, nil
 273  }
 274  
 275  // GetClientCertificateProvider returns a default client certificate source, if
 276  // not provided by the user.
 277  //
 278  // A nil default source can be returned if the source does not exist. Any exceptions
 279  // encountered while initializing the default source will be reported as client
 280  // error (ex. corrupt metadata file).
 281  func GetClientCertificateProvider(opts *Options) (cert.Provider, error) {
 282  	if !isClientCertificateEnabled(opts) {
 283  		return nil, nil
 284  	} else if opts.ClientCertProvider != nil {
 285  		return opts.ClientCertProvider, nil
 286  	}
 287  	return cert.DefaultProvider()
 288  
 289  }
 290  
 291  // isClientCertificateEnabled returns true by default for all GDU universe domain, unless explicitly overridden by env var
 292  func isClientCertificateEnabled(opts *Options) bool {
 293  	if value, ok := os.LookupEnv(googleAPIUseCertSource); ok {
 294  		// error as false is OK
 295  		b, _ := strconv.ParseBool(value)
 296  		return b
 297  	}
 298  	return opts.isUniverseDomainGDU()
 299  }
 300  
 301  type transportConfig struct {
 302  	// The client certificate source.
 303  	clientCertSource cert.Provider
 304  	// The corresponding endpoint to use based on client certificate source.
 305  	endpoint string
 306  	// The plaintext S2A address if it can be used, otherwise an empty string.
 307  	s2aAddress string
 308  	// The MTLS S2A address if it can be used, otherwise an empty string.
 309  	mtlsS2AAddress string
 310  	// The MTLS endpoint to use with S2A.
 311  	s2aMTLSEndpoint string
 312  }
 313  
 314  // getEndpoint returns the endpoint for the service, taking into account the
 315  // user-provided endpoint override "settings.Endpoint".
 316  //
 317  // If no endpoint override is specified, we will either return the default
 318  // endpoint or the default mTLS endpoint if a client certificate is available.
 319  //
 320  // You can override the default endpoint choice (mTLS vs. regular) by setting
 321  // the GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
 322  //
 323  // If the endpoint override is an address (host:port) rather than full base
 324  // URL (ex. https://...), then the user-provided address will be merged into
 325  // the default endpoint. For example, WithEndpoint("myhost:8000") and
 326  // DefaultEndpointTemplate("https://UNIVERSE_DOMAIN/bar/baz") will return
 327  // "https://myhost:8080/bar/baz". Note that this does not apply to the mTLS
 328  // endpoint.
 329  func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) {
 330  	if opts.Endpoint == "" {
 331  		mtlsMode := getMTLSMode()
 332  		if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
 333  			return opts.defaultMTLSEndpoint(), nil
 334  		}
 335  		return opts.defaultEndpoint(), nil
 336  	}
 337  	if strings.Contains(opts.Endpoint, "://") {
 338  		// User passed in a full URL path, use it verbatim.
 339  		return opts.Endpoint, nil
 340  	}
 341  	if opts.defaultEndpoint() == "" {
 342  		// If DefaultEndpointTemplate is not configured,
 343  		// use the user provided endpoint verbatim. This allows a naked
 344  		// "host[:port]" URL to be used with GRPC Direct Path.
 345  		return opts.Endpoint, nil
 346  	}
 347  
 348  	// Assume user-provided endpoint is host[:port], merge it with the default endpoint.
 349  	return opts.mergedEndpoint()
 350  }
 351  
 352  func getMTLSMode() string {
 353  	mode := os.Getenv(googleAPIUseMTLS)
 354  	if mode == "" {
 355  		mode = os.Getenv(googleAPIUseMTLSOld) // Deprecated.
 356  	}
 357  	if mode == "" {
 358  		return mTLSModeAuto
 359  	}
 360  	return strings.ToLower(mode)
 361  }
 362