tlsconfigstore.go raw

   1  /*
   2   *
   3   * Copyright 2022 Google LLC
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     https://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  // Package tlsconfigstore offloads operations to S2Av2.
  20  package tlsconfigstore
  21  
  22  import (
  23  	"crypto/tls"
  24  	"crypto/x509"
  25  	"encoding/pem"
  26  	"errors"
  27  	"fmt"
  28  
  29  	"github.com/google/s2a-go/internal/tokenmanager"
  30  	"github.com/google/s2a-go/internal/v2/certverifier"
  31  	"github.com/google/s2a-go/internal/v2/remotesigner"
  32  	"github.com/google/s2a-go/stream"
  33  	"google.golang.org/grpc/codes"
  34  	"google.golang.org/grpc/grpclog"
  35  
  36  	commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
  37  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
  38  )
  39  
  40  const (
  41  	// HTTP/2
  42  	h2 = "h2"
  43  )
  44  
  45  // GetTLSConfigurationForClient returns a tls.Config instance for use by a client application.
  46  func GetTLSConfigurationForClient(serverHostname string, s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentity *commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, serverAuthorizationPolicy []byte) (*tls.Config, error) {
  47  	authMechanisms := getAuthMechanisms(tokenManager, []*commonpb.Identity{localIdentity})
  48  
  49  	if grpclog.V(1) {
  50  		grpclog.Infof("Sending request to S2Av2 for client TLS config.")
  51  	}
  52  	// Send request to S2Av2 for config.
  53  	if err := s2AStream.Send(&s2av2pb.SessionReq{
  54  		LocalIdentity:            localIdentity,
  55  		AuthenticationMechanisms: authMechanisms,
  56  		ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
  57  			GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
  58  				ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT,
  59  			},
  60  		},
  61  	}); err != nil {
  62  		grpclog.Infof("Failed to send request to S2Av2 for client TLS config")
  63  		return nil, err
  64  	}
  65  
  66  	// Get the response containing config from S2Av2.
  67  	resp, err := s2AStream.Recv()
  68  	if err != nil {
  69  		grpclog.Infof("Failed to receive client TLS config response from S2Av2.")
  70  		return nil, err
  71  	}
  72  
  73  	// TODO(rmehta19): Add unit test for this if statement.
  74  	if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
  75  		return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
  76  	}
  77  
  78  	// Extract TLS configuration from SessionResp.
  79  	tlsConfig := resp.GetGetTlsConfigurationResp().GetClientTlsConfiguration()
  80  
  81  	var cert tls.Certificate
  82  	for i, v := range tlsConfig.CertificateChain {
  83  		// Populate Certificates field.
  84  		block, _ := pem.Decode([]byte(v))
  85  		if block == nil {
  86  			return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
  87  		}
  88  		x509Cert, err := x509.ParseCertificate(block.Bytes)
  89  		if err != nil {
  90  			return nil, err
  91  		}
  92  		cert.Certificate = append(cert.Certificate, x509Cert.Raw)
  93  		if i == 0 {
  94  			cert.Leaf = x509Cert
  95  		}
  96  	}
  97  
  98  	if len(tlsConfig.CertificateChain) > 0 {
  99  		cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
 100  		if cert.PrivateKey == nil {
 101  			return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
 102  		}
 103  	}
 104  
 105  	minVersion, maxVersion, err := getTLSMinMaxVersionsClient(tlsConfig)
 106  	if err != nil {
 107  		return nil, err
 108  	}
 109  
 110  	// Create mTLS credentials for client.
 111  	config := &tls.Config{
 112  		VerifyPeerCertificate:  certverifier.VerifyServerCertificateChain(serverHostname, verificationMode, s2AStream, serverAuthorizationPolicy),
 113  		ServerName:             serverHostname,
 114  		InsecureSkipVerify:     true, // NOLINT
 115  		ClientSessionCache:     nil,
 116  		SessionTicketsDisabled: true,
 117  		MinVersion:             minVersion,
 118  		MaxVersion:             maxVersion,
 119  		NextProtos:             []string{h2},
 120  	}
 121  	if len(tlsConfig.CertificateChain) > 0 {
 122  		config.Certificates = []tls.Certificate{cert}
 123  	}
 124  	return config, nil
 125  }
 126  
 127  // GetTLSConfigurationForServer returns a tls.Config instance for use by a server application.
 128  func GetTLSConfigurationForServer(s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode) (*tls.Config, error) {
 129  	return &tls.Config{
 130  		GetConfigForClient: ClientConfig(tokenManager, localIdentities, verificationMode, s2AStream),
 131  	}, nil
 132  }
 133  
 134  // ClientConfig builds a TLS config for a server to establish a secure
 135  // connection with a client, based on SNI communicated during ClientHello.
 136  // Ensures that server presents the correct certificate to establish a TLS
 137  // connection.
 138  func ClientConfig(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, s2AStream stream.S2AStream) func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
 139  	return func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
 140  		tlsConfig, err := getServerConfigFromS2Av2(tokenManager, localIdentities, chi.ServerName, s2AStream)
 141  		if err != nil {
 142  			return nil, err
 143  		}
 144  
 145  		var cert tls.Certificate
 146  		for i, v := range tlsConfig.CertificateChain {
 147  			// Populate Certificates field.
 148  			block, _ := pem.Decode([]byte(v))
 149  			if block == nil {
 150  				return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
 151  			}
 152  			x509Cert, err := x509.ParseCertificate(block.Bytes)
 153  			if err != nil {
 154  				return nil, err
 155  			}
 156  			cert.Certificate = append(cert.Certificate, x509Cert.Raw)
 157  			if i == 0 {
 158  				cert.Leaf = x509Cert
 159  			}
 160  		}
 161  
 162  		cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
 163  		if cert.PrivateKey == nil {
 164  			return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
 165  		}
 166  
 167  		minVersion, maxVersion, err := getTLSMinMaxVersionsServer(tlsConfig)
 168  		if err != nil {
 169  			return nil, err
 170  		}
 171  
 172  		clientAuth := getTLSClientAuthType(tlsConfig)
 173  
 174  		var cipherSuites []uint16
 175  		cipherSuites = getCipherSuites(tlsConfig.Ciphersuites)
 176  
 177  		// Create mTLS credentials for server.
 178  		return &tls.Config{
 179  			Certificates:           []tls.Certificate{cert},
 180  			VerifyPeerCertificate:  certverifier.VerifyClientCertificateChain(verificationMode, s2AStream),
 181  			ClientAuth:             clientAuth,
 182  			CipherSuites:           cipherSuites,
 183  			SessionTicketsDisabled: true,
 184  			MinVersion:             minVersion,
 185  			MaxVersion:             maxVersion,
 186  			NextProtos:             []string{h2},
 187  		}, nil
 188  	}
 189  }
 190  
 191  func getCipherSuites(tlsConfigCipherSuites []commonpb.Ciphersuite) []uint16 {
 192  	var tlsGoCipherSuites []uint16
 193  	for _, v := range tlsConfigCipherSuites {
 194  		s := getTLSCipherSuite(v)
 195  		if s != 0xffff {
 196  			tlsGoCipherSuites = append(tlsGoCipherSuites, s)
 197  		}
 198  	}
 199  	return tlsGoCipherSuites
 200  }
 201  
 202  func getTLSCipherSuite(tlsCipherSuite commonpb.Ciphersuite) uint16 {
 203  	switch tlsCipherSuite {
 204  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
 205  		return tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
 206  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
 207  		return tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
 208  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
 209  		return tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
 210  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
 211  		return tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
 212  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
 213  		return tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
 214  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
 215  		return tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
 216  	default:
 217  		return 0xffff
 218  	}
 219  }
 220  
 221  func getServerConfigFromS2Av2(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, sni string, s2AStream stream.S2AStream) (*s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration, error) {
 222  	authMechanisms := getAuthMechanisms(tokenManager, localIdentities)
 223  	var locID *commonpb.Identity
 224  	if localIdentities != nil {
 225  		locID = localIdentities[0]
 226  	}
 227  
 228  	if err := s2AStream.Send(&s2av2pb.SessionReq{
 229  		LocalIdentity:            locID,
 230  		AuthenticationMechanisms: authMechanisms,
 231  		ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
 232  			GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
 233  				ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_SERVER,
 234  				Sni:            sni,
 235  			},
 236  		},
 237  	}); err != nil {
 238  		return nil, err
 239  	}
 240  
 241  	resp, err := s2AStream.Recv()
 242  	if err != nil {
 243  		return nil, err
 244  	}
 245  
 246  	// TODO(rmehta19): Add unit test for this if statement.
 247  	if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
 248  		return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
 249  	}
 250  
 251  	return resp.GetGetTlsConfigurationResp().GetServerTlsConfiguration(), nil
 252  }
 253  
 254  func getTLSClientAuthType(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) tls.ClientAuthType {
 255  	var clientAuth tls.ClientAuthType
 256  	switch x := tlsConfig.RequestClientCertificate; x {
 257  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_DONT_REQUEST_CLIENT_CERTIFICATE:
 258  		clientAuth = tls.NoClientCert
 259  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
 260  		clientAuth = tls.RequestClientCert
 261  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
 262  		// This case actually maps to tls.VerifyClientCertIfGiven. However this
 263  		// mapping triggers normal verification, followed by custom verification,
 264  		// specified in VerifyPeerCertificate. To bypass normal verification, and
 265  		// only do custom verification we set clientAuth to RequireAnyClientCert or
 266  		// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
 267  		// discussion.
 268  		clientAuth = tls.RequireAnyClientCert
 269  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
 270  		clientAuth = tls.RequireAnyClientCert
 271  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
 272  		// This case actually maps to tls.RequireAndVerifyClientCert. However this
 273  		// mapping triggers normal verification, followed by custom verification,
 274  		// specified in VerifyPeerCertificate. To bypass normal verification, and
 275  		// only do custom verification we set clientAuth to RequireAnyClientCert or
 276  		// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
 277  		// discussion.
 278  		clientAuth = tls.RequireAnyClientCert
 279  	default:
 280  		clientAuth = tls.RequireAnyClientCert
 281  	}
 282  	return clientAuth
 283  }
 284  
 285  func getAuthMechanisms(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity) []*s2av2pb.AuthenticationMechanism {
 286  	if tokenManager == nil {
 287  		return nil
 288  	}
 289  	if len(localIdentities) == 0 {
 290  		token, err := tokenManager.DefaultToken()
 291  		if err != nil {
 292  			grpclog.Infof("Unable to get token for empty local identity: %v", err)
 293  			return nil
 294  		}
 295  		return []*s2av2pb.AuthenticationMechanism{
 296  			{
 297  				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
 298  					Token: token,
 299  				},
 300  			},
 301  		}
 302  	}
 303  	var authMechanisms []*s2av2pb.AuthenticationMechanism
 304  	for _, localIdentity := range localIdentities {
 305  		if localIdentity == nil {
 306  			token, err := tokenManager.DefaultToken()
 307  			if err != nil {
 308  				grpclog.Infof("Unable to get default token for local identity %v: %v", localIdentity, err)
 309  				continue
 310  			}
 311  			authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
 312  				Identity: localIdentity,
 313  				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
 314  					Token: token,
 315  				},
 316  			})
 317  		} else {
 318  			token, err := tokenManager.Token(localIdentity)
 319  			if err != nil {
 320  				grpclog.Infof("Unable to get token for local identity %v: %v", localIdentity, err)
 321  				continue
 322  			}
 323  			authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
 324  				Identity: localIdentity,
 325  				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
 326  					Token: token,
 327  				},
 328  			})
 329  		}
 330  	}
 331  	return authMechanisms
 332  }
 333  
 334  // TODO(rmehta19): refactor switch statements into a helper function.
 335  func getTLSMinMaxVersionsClient(tlsConfig *s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration) (uint16, uint16, error) {
 336  	// Map S2Av2 TLSVersion to consts defined in tls package.
 337  	var minVersion uint16
 338  	var maxVersion uint16
 339  	switch x := tlsConfig.MinTlsVersion; x {
 340  	case commonpb.TLSVersion_TLS_VERSION_1_0:
 341  		minVersion = tls.VersionTLS10
 342  	case commonpb.TLSVersion_TLS_VERSION_1_1:
 343  		minVersion = tls.VersionTLS11
 344  	case commonpb.TLSVersion_TLS_VERSION_1_2:
 345  		minVersion = tls.VersionTLS12
 346  	case commonpb.TLSVersion_TLS_VERSION_1_3:
 347  		minVersion = tls.VersionTLS13
 348  	default:
 349  		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
 350  	}
 351  
 352  	switch x := tlsConfig.MaxTlsVersion; x {
 353  	case commonpb.TLSVersion_TLS_VERSION_1_0:
 354  		maxVersion = tls.VersionTLS10
 355  	case commonpb.TLSVersion_TLS_VERSION_1_1:
 356  		maxVersion = tls.VersionTLS11
 357  	case commonpb.TLSVersion_TLS_VERSION_1_2:
 358  		maxVersion = tls.VersionTLS12
 359  	case commonpb.TLSVersion_TLS_VERSION_1_3:
 360  		maxVersion = tls.VersionTLS13
 361  	default:
 362  		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
 363  	}
 364  	if minVersion > maxVersion {
 365  		return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
 366  	}
 367  	return minVersion, maxVersion, nil
 368  }
 369  
 370  func getTLSMinMaxVersionsServer(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) (uint16, uint16, error) {
 371  	// Map S2Av2 TLSVersion to consts defined in tls package.
 372  	var minVersion uint16
 373  	var maxVersion uint16
 374  	switch x := tlsConfig.MinTlsVersion; x {
 375  	case commonpb.TLSVersion_TLS_VERSION_1_0:
 376  		minVersion = tls.VersionTLS10
 377  	case commonpb.TLSVersion_TLS_VERSION_1_1:
 378  		minVersion = tls.VersionTLS11
 379  	case commonpb.TLSVersion_TLS_VERSION_1_2:
 380  		minVersion = tls.VersionTLS12
 381  	case commonpb.TLSVersion_TLS_VERSION_1_3:
 382  		minVersion = tls.VersionTLS13
 383  	default:
 384  		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
 385  	}
 386  
 387  	switch x := tlsConfig.MaxTlsVersion; x {
 388  	case commonpb.TLSVersion_TLS_VERSION_1_0:
 389  		maxVersion = tls.VersionTLS10
 390  	case commonpb.TLSVersion_TLS_VERSION_1_1:
 391  		maxVersion = tls.VersionTLS11
 392  	case commonpb.TLSVersion_TLS_VERSION_1_2:
 393  		maxVersion = tls.VersionTLS12
 394  	case commonpb.TLSVersion_TLS_VERSION_1_3:
 395  		maxVersion = tls.VersionTLS13
 396  	default:
 397  		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
 398  	}
 399  	if minVersion > maxVersion {
 400  		return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
 401  	}
 402  	return minVersion, maxVersion, nil
 403  }
 404