s2av2.go raw

   1  /*
   2   *
   3   * Copyright 2022 Google LLC
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     https://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  // Package v2 provides the S2Av2 transport credentials used by a gRPC
  20  // application.
  21  package v2
  22  
  23  import (
  24  	"context"
  25  	"crypto/tls"
  26  	"errors"
  27  	"net"
  28  	"os"
  29  	"time"
  30  
  31  	"github.com/google/s2a-go/fallback"
  32  	"github.com/google/s2a-go/internal/handshaker/service"
  33  	"github.com/google/s2a-go/internal/tokenmanager"
  34  	"github.com/google/s2a-go/internal/v2/tlsconfigstore"
  35  	"github.com/google/s2a-go/retry"
  36  	"github.com/google/s2a-go/stream"
  37  	"google.golang.org/grpc"
  38  	"google.golang.org/grpc/credentials"
  39  	"google.golang.org/grpc/grpclog"
  40  	"google.golang.org/protobuf/proto"
  41  
  42  	commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
  43  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
  44  )
  45  
  46  const (
  47  	s2aSecurityProtocol = "tls"
  48  	defaultS2ATimeout   = 6 * time.Second
  49  )
  50  
  51  // An environment variable, which sets the timeout enforced on the connection to the S2A service for handshake.
  52  const s2aTimeoutEnv = "S2A_TIMEOUT"
  53  
  54  type s2av2TransportCreds struct {
  55  	info           *credentials.ProtocolInfo
  56  	isClient       bool
  57  	serverName     string
  58  	s2av2Address   string
  59  	transportCreds credentials.TransportCredentials
  60  	tokenManager   *tokenmanager.AccessTokenManager
  61  	// localIdentity should only be used by the client.
  62  	localIdentity *commonpb.Identity
  63  	// localIdentities should only be used by the server.
  64  	localIdentities           []*commonpb.Identity
  65  	verificationMode          s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
  66  	fallbackClientHandshake   fallback.ClientHandshake
  67  	getS2AStream              stream.GetS2AStream
  68  	serverAuthorizationPolicy []byte
  69  }
  70  
  71  // NewClientCreds returns a client-side transport credentials object that uses
  72  // the S2Av2 to establish a secure connection with a server.
  73  func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream stream.GetS2AStream, serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
  74  	// Create an AccessTokenManager instance to use to authenticate to S2Av2.
  75  	accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
  76  
  77  	creds := &s2av2TransportCreds{
  78  		info: &credentials.ProtocolInfo{
  79  			SecurityProtocol: s2aSecurityProtocol,
  80  		},
  81  		isClient:                  true,
  82  		serverName:                "",
  83  		s2av2Address:              s2av2Address,
  84  		transportCreds:            transportCreds,
  85  		localIdentity:             localIdentity,
  86  		verificationMode:          verificationMode,
  87  		fallbackClientHandshake:   fallbackClientHandshakeFunc,
  88  		getS2AStream:              getS2AStream,
  89  		serverAuthorizationPolicy: serverAuthorizationPolicy,
  90  	}
  91  	if err != nil {
  92  		creds.tokenManager = nil
  93  	} else {
  94  		creds.tokenManager = &accessTokenManager
  95  	}
  96  	if grpclog.V(1) {
  97  		grpclog.Info("Created client S2Av2 transport credentials.")
  98  	}
  99  	return creds, nil
 100  }
 101  
 102  // NewServerCreds returns a server-side transport credentials object that uses
 103  // the S2Av2 to establish a secure connection with a client.
 104  func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream stream.GetS2AStream) (credentials.TransportCredentials, error) {
 105  	// Create an AccessTokenManager instance to use to authenticate to S2Av2.
 106  	accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
 107  	creds := &s2av2TransportCreds{
 108  		info: &credentials.ProtocolInfo{
 109  			SecurityProtocol: s2aSecurityProtocol,
 110  		},
 111  		isClient:         false,
 112  		s2av2Address:     s2av2Address,
 113  		transportCreds:   transportCreds,
 114  		localIdentities:  localIdentities,
 115  		verificationMode: verificationMode,
 116  		getS2AStream:     getS2AStream,
 117  	}
 118  	if err != nil {
 119  		creds.tokenManager = nil
 120  	} else {
 121  		creds.tokenManager = &accessTokenManager
 122  	}
 123  	if grpclog.V(1) {
 124  		grpclog.Info("Created server S2Av2 transport credentials.")
 125  	}
 126  	return creds, nil
 127  }
 128  
 129  // ClientHandshake performs a client-side mTLS handshake using the S2Av2.
 130  func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 131  	if !c.isClient {
 132  		return nil, nil, errors.New("client handshake called using server transport credentials")
 133  	}
 134  	// Remove the port from serverAuthority.
 135  	serverName := removeServerNamePort(serverAuthority)
 136  	timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
 137  	defer cancel()
 138  	var s2AStream stream.S2AStream
 139  	var err error
 140  	retry.Run(timeoutCtx,
 141  		func() error {
 142  			s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
 143  			return err
 144  		})
 145  	if err != nil {
 146  		grpclog.Infof("Failed to connect to S2Av2: %v", err)
 147  		if c.fallbackClientHandshake != nil {
 148  			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
 149  		}
 150  		return nil, nil, err
 151  	}
 152  	defer s2AStream.CloseSend()
 153  	if grpclog.V(1) {
 154  		grpclog.Infof("Connected to S2Av2.")
 155  	}
 156  	var config *tls.Config
 157  
 158  	var tokenManager tokenmanager.AccessTokenManager
 159  	if c.tokenManager == nil {
 160  		tokenManager = nil
 161  	} else {
 162  		tokenManager = *c.tokenManager
 163  	}
 164  
 165  	sn := serverName
 166  	if c.serverName != "" {
 167  		sn = c.serverName
 168  	}
 169  	retry.Run(timeoutCtx,
 170  		func() error {
 171  			config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
 172  			return err
 173  		})
 174  	if err != nil {
 175  		grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
 176  		if c.fallbackClientHandshake != nil {
 177  			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
 178  		}
 179  		return nil, nil, err
 180  	}
 181  	if grpclog.V(1) {
 182  		grpclog.Infof("Got client TLS config from S2Av2.")
 183  	}
 184  
 185  	creds := credentials.NewTLS(config)
 186  	conn, authInfo, err := creds.ClientHandshake(timeoutCtx, serverName, rawConn)
 187  	if err != nil {
 188  		grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
 189  		if c.fallbackClientHandshake != nil {
 190  			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
 191  		}
 192  		return nil, nil, err
 193  	}
 194  	grpclog.Infof("client-side handshake is done using S2Av2 to: %s", serverName)
 195  
 196  	return conn, authInfo, err
 197  }
 198  
 199  // ServerHandshake performs a server-side mTLS handshake using the S2Av2.
 200  func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 201  	if c.isClient {
 202  		return nil, nil, errors.New("server handshake called using client transport credentials")
 203  	}
 204  	ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
 205  	defer cancel()
 206  	var s2AStream stream.S2AStream
 207  	var err error
 208  	retry.Run(ctx,
 209  		func() error {
 210  			s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
 211  			return err
 212  		})
 213  	if err != nil {
 214  		grpclog.Infof("Failed to connect to S2Av2: %v", err)
 215  		return nil, nil, err
 216  	}
 217  	defer s2AStream.CloseSend()
 218  	if grpclog.V(1) {
 219  		grpclog.Infof("Connected to S2Av2.")
 220  	}
 221  
 222  	var tokenManager tokenmanager.AccessTokenManager
 223  	if c.tokenManager == nil {
 224  		tokenManager = nil
 225  	} else {
 226  		tokenManager = *c.tokenManager
 227  	}
 228  
 229  	var config *tls.Config
 230  	retry.Run(ctx,
 231  		func() error {
 232  			config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
 233  			return err
 234  		})
 235  	if err != nil {
 236  		grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
 237  		return nil, nil, err
 238  	}
 239  	if grpclog.V(1) {
 240  		grpclog.Infof("Got server TLS config from S2Av2.")
 241  	}
 242  
 243  	creds := credentials.NewTLS(config)
 244  	conn, authInfo, err := creds.ServerHandshake(rawConn)
 245  	if err != nil {
 246  		grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
 247  		return nil, nil, err
 248  	}
 249  	return conn, authInfo, err
 250  }
 251  
 252  // Info returns protocol info of s2av2TransportCreds.
 253  func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
 254  	return *c.info
 255  }
 256  
 257  // Clone makes a deep copy of s2av2TransportCreds.
 258  func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
 259  	info := *c.info
 260  	serverName := c.serverName
 261  	fallbackClientHandshake := c.fallbackClientHandshake
 262  
 263  	s2av2Address := c.s2av2Address
 264  	var tokenManager tokenmanager.AccessTokenManager
 265  	if c.tokenManager == nil {
 266  		tokenManager = nil
 267  	} else {
 268  		tokenManager = *c.tokenManager
 269  	}
 270  	verificationMode := c.verificationMode
 271  	var localIdentity *commonpb.Identity
 272  	if c.localIdentity != nil {
 273  		localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
 274  	}
 275  	var localIdentities []*commonpb.Identity
 276  	if c.localIdentities != nil {
 277  		localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
 278  		for i, localIdentity := range c.localIdentities {
 279  			localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
 280  		}
 281  	}
 282  	creds := &s2av2TransportCreds{
 283  		info:                    &info,
 284  		isClient:                c.isClient,
 285  		serverName:              serverName,
 286  		fallbackClientHandshake: fallbackClientHandshake,
 287  		s2av2Address:            s2av2Address,
 288  		localIdentity:           localIdentity,
 289  		localIdentities:         localIdentities,
 290  		verificationMode:        verificationMode,
 291  	}
 292  	if c.tokenManager == nil {
 293  		creds.tokenManager = nil
 294  	} else {
 295  		creds.tokenManager = &tokenManager
 296  	}
 297  	return creds
 298  }
 299  
 300  // NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
 301  // a client. The tls.Config MUST only be used to establish a single TLS connection.
 302  func NewClientTLSConfig(
 303  	ctx context.Context,
 304  	s2av2Address string,
 305  	transportCreds credentials.TransportCredentials,
 306  	tokenManager tokenmanager.AccessTokenManager,
 307  	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
 308  	serverName string,
 309  	serverAuthorizationPolicy []byte,
 310  	getStream stream.GetS2AStream) (*tls.Config, error) {
 311  	s2AStream, err := createStream(ctx, s2av2Address, transportCreds, getStream)
 312  	if err != nil {
 313  		grpclog.Infof("Failed to connect to S2Av2: %v", err)
 314  		return nil, err
 315  	}
 316  
 317  	return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
 318  }
 319  
 320  // OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
 321  // info. The ServerName MUST be a hostname.
 322  func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
 323  	serverName := removeServerNamePort(serverNameOverride)
 324  	c.info.ServerName = serverName
 325  	c.serverName = serverName
 326  	return nil
 327  }
 328  
 329  // Remove the trailing port from server name.
 330  func removeServerNamePort(serverName string) string {
 331  	name, _, err := net.SplitHostPort(serverName)
 332  	if err != nil {
 333  		name = serverName
 334  	}
 335  	return name
 336  }
 337  
 338  type s2AGrpcStream struct {
 339  	stream s2av2pb.S2AService_SetUpSessionClient
 340  }
 341  
 342  func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
 343  	return x.stream.Send(m)
 344  }
 345  
 346  func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
 347  	return x.stream.Recv()
 348  }
 349  
 350  func (x s2AGrpcStream) CloseSend() error {
 351  	return x.stream.CloseSend()
 352  }
 353  
 354  func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream stream.GetS2AStream) (stream.S2AStream, error) {
 355  	if getS2AStream != nil {
 356  		return getS2AStream(ctx, s2av2Address)
 357  	}
 358  	// TODO(rmehta19): Consider whether to close the connection to S2Av2.
 359  	conn, err := service.Dial(ctx, s2av2Address, transportCreds)
 360  	if err != nil {
 361  		return nil, err
 362  	}
 363  	client := s2av2pb.NewS2AServiceClient(conn)
 364  	gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
 365  	if err != nil {
 366  		return nil, err
 367  	}
 368  	return &s2AGrpcStream{
 369  		stream: gRPCStream,
 370  	}, nil
 371  }
 372  
 373  // GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
 374  func GetS2ATimeout() time.Duration {
 375  	timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
 376  	if err != nil {
 377  		return defaultS2ATimeout
 378  	}
 379  	return timeout
 380  }
 381