s2a.go raw

   1  /*
   2   *
   3   * Copyright 2021 Google LLC
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     https://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  // Package s2a provides the S2A transport credentials used by a gRPC
  20  // application.
  21  package s2a
  22  
  23  import (
  24  	"context"
  25  	"crypto/tls"
  26  	"errors"
  27  	"fmt"
  28  	"net"
  29  	"sync"
  30  	"time"
  31  
  32  	"github.com/google/s2a-go/fallback"
  33  	"github.com/google/s2a-go/internal/handshaker"
  34  	"github.com/google/s2a-go/internal/handshaker/service"
  35  	"github.com/google/s2a-go/internal/tokenmanager"
  36  	"github.com/google/s2a-go/internal/v2"
  37  	"github.com/google/s2a-go/retry"
  38  	"github.com/google/s2a-go/stream"
  39  	"google.golang.org/grpc/credentials"
  40  	"google.golang.org/grpc/grpclog"
  41  	"google.golang.org/protobuf/proto"
  42  
  43  	commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
  44  	commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
  45  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
  46  )
  47  
  48  const (
  49  	s2aSecurityProtocol = "tls"
  50  	// defaultTimeout specifies the default server handshake timeout.
  51  	defaultTimeout = 30.0 * time.Second
  52  )
  53  
  54  // s2aTransportCreds are the transport credentials required for establishing
  55  // a secure connection using the S2A. They implement the
  56  // credentials.TransportCredentials interface.
  57  type s2aTransportCreds struct {
  58  	info          *credentials.ProtocolInfo
  59  	minTLSVersion commonpbv1.TLSVersion
  60  	maxTLSVersion commonpbv1.TLSVersion
  61  	// tlsCiphersuites contains the ciphersuites used in the S2A connection.
  62  	// Note that these are currently unconfigurable.
  63  	tlsCiphersuites []commonpbv1.Ciphersuite
  64  	// localIdentity should only be used by the client.
  65  	localIdentity *commonpbv1.Identity
  66  	// localIdentities should only be used by the server.
  67  	localIdentities []*commonpbv1.Identity
  68  	// targetIdentities should only be used by the client.
  69  	targetIdentities            []*commonpbv1.Identity
  70  	isClient                    bool
  71  	s2aAddr                     string
  72  	ensureProcessSessionTickets *sync.WaitGroup
  73  }
  74  
  75  // NewClientCreds returns a client-side transport credentials object that uses
  76  // the S2A to establish a secure connection with a server.
  77  func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
  78  	if opts == nil {
  79  		return nil, errors.New("nil client options")
  80  	}
  81  	var targetIdentities []*commonpbv1.Identity
  82  	for _, targetIdentity := range opts.TargetIdentities {
  83  		protoTargetIdentity, err := toProtoIdentity(targetIdentity)
  84  		if err != nil {
  85  			return nil, err
  86  		}
  87  		targetIdentities = append(targetIdentities, protoTargetIdentity)
  88  	}
  89  	localIdentity, err := toProtoIdentity(opts.LocalIdentity)
  90  	if err != nil {
  91  		return nil, err
  92  	}
  93  	if opts.EnableLegacyMode {
  94  		return &s2aTransportCreds{
  95  			info: &credentials.ProtocolInfo{
  96  				SecurityProtocol: s2aSecurityProtocol,
  97  			},
  98  			minTLSVersion: commonpbv1.TLSVersion_TLS1_3,
  99  			maxTLSVersion: commonpbv1.TLSVersion_TLS1_3,
 100  			tlsCiphersuites: []commonpbv1.Ciphersuite{
 101  				commonpbv1.Ciphersuite_AES_128_GCM_SHA256,
 102  				commonpbv1.Ciphersuite_AES_256_GCM_SHA384,
 103  				commonpbv1.Ciphersuite_CHACHA20_POLY1305_SHA256,
 104  			},
 105  			localIdentity:               localIdentity,
 106  			targetIdentities:            targetIdentities,
 107  			isClient:                    true,
 108  			s2aAddr:                     opts.S2AAddress,
 109  			ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
 110  		}, nil
 111  	}
 112  	verificationMode := getVerificationMode(opts.VerificationMode)
 113  	var fallbackFunc fallback.ClientHandshake
 114  	if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
 115  		fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
 116  	}
 117  	v2LocalIdentity, err := toV2ProtoIdentity(opts.LocalIdentity)
 118  	if err != nil {
 119  		return nil, err
 120  	}
 121  	return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, v2LocalIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
 122  }
 123  
 124  // NewServerCreds returns a server-side transport credentials object that uses
 125  // the S2A to establish a secure connection with a client.
 126  func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
 127  	if opts == nil {
 128  		return nil, errors.New("nil server options")
 129  	}
 130  	var localIdentities []*commonpbv1.Identity
 131  	for _, localIdentity := range opts.LocalIdentities {
 132  		protoLocalIdentity, err := toProtoIdentity(localIdentity)
 133  		if err != nil {
 134  			return nil, err
 135  		}
 136  		localIdentities = append(localIdentities, protoLocalIdentity)
 137  	}
 138  	if opts.EnableLegacyMode {
 139  		return &s2aTransportCreds{
 140  			info: &credentials.ProtocolInfo{
 141  				SecurityProtocol: s2aSecurityProtocol,
 142  			},
 143  			minTLSVersion: commonpbv1.TLSVersion_TLS1_3,
 144  			maxTLSVersion: commonpbv1.TLSVersion_TLS1_3,
 145  			tlsCiphersuites: []commonpbv1.Ciphersuite{
 146  				commonpbv1.Ciphersuite_AES_128_GCM_SHA256,
 147  				commonpbv1.Ciphersuite_AES_256_GCM_SHA384,
 148  				commonpbv1.Ciphersuite_CHACHA20_POLY1305_SHA256,
 149  			},
 150  			localIdentities: localIdentities,
 151  			isClient:        false,
 152  			s2aAddr:         opts.S2AAddress,
 153  		}, nil
 154  	}
 155  	verificationMode := getVerificationMode(opts.VerificationMode)
 156  	var v2LocalIdentities []*commonpb.Identity
 157  	for _, localIdentity := range opts.LocalIdentities {
 158  		protoLocalIdentity, err := toV2ProtoIdentity(localIdentity)
 159  		if err != nil {
 160  			return nil, err
 161  		}
 162  		v2LocalIdentities = append(v2LocalIdentities, protoLocalIdentity)
 163  	}
 164  	return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, v2LocalIdentities, verificationMode, opts.getS2AStream)
 165  }
 166  
 167  // ClientHandshake initiates a client-side TLS handshake using the S2A.
 168  func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 169  	if !c.isClient {
 170  		return nil, nil, errors.New("client handshake called using server transport credentials")
 171  	}
 172  
 173  	var cancel context.CancelFunc
 174  	ctx, cancel = context.WithCancel(ctx)
 175  	defer cancel()
 176  
 177  	// Connect to the S2A.
 178  	hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
 179  	if err != nil {
 180  		grpclog.Infof("Failed to connect to S2A: %v", err)
 181  		return nil, nil, err
 182  	}
 183  
 184  	opts := &handshaker.ClientHandshakerOptions{
 185  		MinTLSVersion:               c.minTLSVersion,
 186  		MaxTLSVersion:               c.maxTLSVersion,
 187  		TLSCiphersuites:             c.tlsCiphersuites,
 188  		TargetIdentities:            c.targetIdentities,
 189  		LocalIdentity:               c.localIdentity,
 190  		TargetName:                  serverAuthority,
 191  		EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
 192  	}
 193  	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
 194  	if err != nil {
 195  		grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
 196  		return nil, nil, err
 197  	}
 198  	defer func() {
 199  		if err != nil {
 200  			if closeErr := chs.Close(); closeErr != nil {
 201  				grpclog.Infof("Close failed unexpectedly: %v", err)
 202  				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
 203  			}
 204  		}
 205  	}()
 206  
 207  	secConn, authInfo, err := chs.ClientHandshake(context.Background())
 208  	if err != nil {
 209  		grpclog.Infof("Handshake failed: %v", err)
 210  		return nil, nil, err
 211  	}
 212  	return secConn, authInfo, nil
 213  }
 214  
 215  // ServerHandshake initiates a server-side TLS handshake using the S2A.
 216  func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 217  	if c.isClient {
 218  		return nil, nil, errors.New("server handshake called using client transport credentials")
 219  	}
 220  
 221  	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
 222  	defer cancel()
 223  
 224  	// Connect to the S2A.
 225  	hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
 226  	if err != nil {
 227  		grpclog.Infof("Failed to connect to S2A: %v", err)
 228  		return nil, nil, err
 229  	}
 230  
 231  	opts := &handshaker.ServerHandshakerOptions{
 232  		MinTLSVersion:   c.minTLSVersion,
 233  		MaxTLSVersion:   c.maxTLSVersion,
 234  		TLSCiphersuites: c.tlsCiphersuites,
 235  		LocalIdentities: c.localIdentities,
 236  	}
 237  	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
 238  	if err != nil {
 239  		grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
 240  		return nil, nil, err
 241  	}
 242  	defer func() {
 243  		if err != nil {
 244  			if closeErr := shs.Close(); closeErr != nil {
 245  				grpclog.Infof("Close failed unexpectedly: %v", err)
 246  				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
 247  			}
 248  		}
 249  	}()
 250  
 251  	secConn, authInfo, err := shs.ServerHandshake(context.Background())
 252  	if err != nil {
 253  		grpclog.Infof("Handshake failed: %v", err)
 254  		return nil, nil, err
 255  	}
 256  	return secConn, authInfo, nil
 257  }
 258  
 259  func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
 260  	return *c.info
 261  }
 262  
 263  func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
 264  	info := *c.info
 265  	var localIdentity *commonpbv1.Identity
 266  	if c.localIdentity != nil {
 267  		localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
 268  	}
 269  	var localIdentities []*commonpbv1.Identity
 270  	if c.localIdentities != nil {
 271  		localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
 272  		for i, localIdentity := range c.localIdentities {
 273  			localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
 274  		}
 275  	}
 276  	var targetIdentities []*commonpbv1.Identity
 277  	if c.targetIdentities != nil {
 278  		targetIdentities = make([]*commonpbv1.Identity, len(c.targetIdentities))
 279  		for i, targetIdentity := range c.targetIdentities {
 280  			targetIdentities[i] = proto.Clone(targetIdentity).(*commonpbv1.Identity)
 281  		}
 282  	}
 283  	return &s2aTransportCreds{
 284  		info:             &info,
 285  		minTLSVersion:    c.minTLSVersion,
 286  		maxTLSVersion:    c.maxTLSVersion,
 287  		tlsCiphersuites:  c.tlsCiphersuites,
 288  		localIdentity:    localIdentity,
 289  		localIdentities:  localIdentities,
 290  		targetIdentities: targetIdentities,
 291  		isClient:         c.isClient,
 292  		s2aAddr:          c.s2aAddr,
 293  	}
 294  }
 295  
 296  func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
 297  	c.info.ServerName = serverNameOverride
 298  	return nil
 299  }
 300  
 301  // TLSClientConfigOptions specifies parameters for creating client TLS config.
 302  type TLSClientConfigOptions struct {
 303  	// ServerName is required by s2a as the expected name when verifying the hostname found in server's certificate.
 304  	// 		tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
 305  	//			ServerName: "example.com",
 306  	//		})
 307  	ServerName string
 308  }
 309  
 310  // TLSClientConfigFactory defines the interface for a client TLS config factory.
 311  type TLSClientConfigFactory interface {
 312  	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
 313  }
 314  
 315  // NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
 316  func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
 317  	if opts == nil {
 318  		return nil, fmt.Errorf("opts must be non-nil")
 319  	}
 320  	if opts.EnableLegacyMode {
 321  		return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
 322  	}
 323  	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
 324  	if err != nil {
 325  		// The only possible error is: access token not set in the environment,
 326  		// which is okay in environments other than serverless.
 327  		grpclog.Infof("Access token manager not initialized: %v", err)
 328  		return &s2aTLSClientConfigFactory{
 329  			s2av2Address:              opts.S2AAddress,
 330  			transportCreds:            opts.TransportCreds,
 331  			tokenManager:              nil,
 332  			verificationMode:          getVerificationMode(opts.VerificationMode),
 333  			serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
 334  			getStream:                 opts.getS2AStream,
 335  		}, nil
 336  	}
 337  	return &s2aTLSClientConfigFactory{
 338  		s2av2Address:              opts.S2AAddress,
 339  		transportCreds:            opts.TransportCreds,
 340  		tokenManager:              tokenManager,
 341  		verificationMode:          getVerificationMode(opts.VerificationMode),
 342  		serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
 343  		getStream:                 opts.getS2AStream,
 344  	}, nil
 345  }
 346  
 347  type s2aTLSClientConfigFactory struct {
 348  	s2av2Address              string
 349  	transportCreds            credentials.TransportCredentials
 350  	tokenManager              tokenmanager.AccessTokenManager
 351  	verificationMode          s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
 352  	serverAuthorizationPolicy []byte
 353  	getStream                 stream.GetS2AStream
 354  }
 355  
 356  func (f *s2aTLSClientConfigFactory) Build(
 357  	ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
 358  	serverName := ""
 359  	if opts != nil && opts.ServerName != "" {
 360  		serverName = opts.ServerName
 361  	}
 362  	return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy, f.getStream)
 363  }
 364  
 365  func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
 366  	switch verificationMode {
 367  	case ConnectToGoogle:
 368  		return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
 369  	case Spiffe:
 370  		return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
 371  	case ReservedCustomVerificationMode3:
 372  		return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_3
 373  	case ReservedCustomVerificationMode4:
 374  		return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_4
 375  	case ReservedCustomVerificationMode5:
 376  		return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_5
 377  	case ReservedCustomVerificationMode6:
 378  		return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_6
 379  	default:
 380  		return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
 381  	}
 382  }
 383  
 384  // NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
 385  // Example use with http.RoundTripper:
 386  //
 387  //		dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
 388  //			S2AAddress:         s2aAddress, // required
 389  //		})
 390  //	 	transport := http.DefaultTransport
 391  //	 	transport.DialTLSContext = dialTLSContext
 392  func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
 393  
 394  	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 395  
 396  		fallback := func(err error) (net.Conn, error) {
 397  			if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
 398  				opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
 399  				fbDialer := opts.FallbackOpts.FallbackDialer
 400  				grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
 401  				fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
 402  				if fbErr != nil {
 403  					return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
 404  				}
 405  				return fbConn, nil
 406  			}
 407  			return nil, err
 408  		}
 409  
 410  		factory, err := NewTLSClientConfigFactory(opts)
 411  		if err != nil {
 412  			grpclog.Infof("error creating S2A client config factory: %v", err)
 413  			return fallback(err)
 414  		}
 415  
 416  		serverName, _, err := net.SplitHostPort(addr)
 417  		if err != nil {
 418  			serverName = addr
 419  		}
 420  		timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
 421  		defer cancel()
 422  
 423  		var s2aTLSConfig *tls.Config
 424  		var c net.Conn
 425  		retry.Run(timeoutCtx,
 426  			func() error {
 427  				s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{
 428  					ServerName: serverName,
 429  				})
 430  				if err != nil {
 431  					grpclog.Infof("error building S2A TLS config: %v", err)
 432  					return err
 433  				}
 434  
 435  				s2aDialer := &tls.Dialer{
 436  					Config: s2aTLSConfig,
 437  				}
 438  				c, err = s2aDialer.DialContext(timeoutCtx, network, addr)
 439  				return err
 440  			})
 441  		if err != nil {
 442  			grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
 443  			return fallback(err)
 444  		}
 445  		grpclog.Infof("success dialing MTLS to %s with S2A", addr)
 446  		return c, nil
 447  	}
 448  }
 449