s2a_fallback.go raw

   1  /*
   2   *
   3   * Copyright 2023 Google LLC
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     https://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  // Package fallback provides default implementations of fallback options when S2A fails.
  20  package fallback
  21  
  22  import (
  23  	"context"
  24  	"crypto/tls"
  25  	"fmt"
  26  	"net"
  27  
  28  	"google.golang.org/grpc/credentials"
  29  	"google.golang.org/grpc/grpclog"
  30  )
  31  
  32  const (
  33  	alpnProtoStrH2   = "h2"
  34  	alpnProtoStrHTTP = "http/1.1"
  35  	defaultHTTPSPort = "443"
  36  )
  37  
  38  // FallbackTLSConfigGRPC is a tls.Config used by the DefaultFallbackClientHandshakeFunc function.
  39  // It supports GRPC use case, thus the alpn is set to 'h2'.
  40  var FallbackTLSConfigGRPC = tls.Config{
  41  	MinVersion:         tls.VersionTLS13,
  42  	ClientSessionCache: nil,
  43  	NextProtos:         []string{alpnProtoStrH2},
  44  }
  45  
  46  // FallbackTLSConfigHTTP is a tls.Config used by the DefaultFallbackDialerAndAddress func.
  47  // It supports the HTTP use case and the alpn is set to both 'http/1.1' and 'h2'.
  48  var FallbackTLSConfigHTTP = tls.Config{
  49  	MinVersion:         tls.VersionTLS13,
  50  	ClientSessionCache: nil,
  51  	NextProtos:         []string{alpnProtoStrH2, alpnProtoStrHTTP},
  52  }
  53  
  54  // ClientHandshake establishes a TLS connection and returns it, plus its auth info.
  55  // Inputs:
  56  //
  57  //	targetServer: the server attempted with S2A.
  58  //	conn: the tcp connection to the server at address targetServer that was passed into S2A's ClientHandshake func.
  59  //	            If fallback is successful, the `conn` should be closed.
  60  //	err: the error encountered when performing the client-side TLS handshake with S2A.
  61  type ClientHandshake func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error)
  62  
  63  // DefaultFallbackClientHandshakeFunc returns a ClientHandshake function,
  64  // which establishes a TLS connection to the provided fallbackAddr, returns the new connection and its auth info.
  65  // Example use:
  66  //
  67  //	transportCreds, _ = s2a.NewClientCreds(&s2a.ClientOptions{
  68  //		S2AAddress: s2aAddress,
  69  //		FallbackOpts: &s2a.FallbackOptions{ // optional
  70  //			FallbackClientHandshakeFunc: fallback.DefaultFallbackClientHandshakeFunc(fallbackAddr),
  71  //		},
  72  //	})
  73  //
  74  // The fallback server's certificate must be verifiable using OS root store.
  75  // The fallbackAddr is expected to be a network address, e.g. example.com:port. If port is not specified,
  76  // it uses default port 443.
  77  // In the returned function's TLS config, ClientSessionCache is explicitly set to nil to disable TLS resumption,
  78  // and min TLS version is set to 1.3.
  79  func DefaultFallbackClientHandshakeFunc(fallbackAddr string) (ClientHandshake, error) {
  80  	var fallbackDialer = tls.Dialer{Config: &FallbackTLSConfigGRPC}
  81  	return defaultFallbackClientHandshakeFuncInternal(fallbackAddr, fallbackDialer.DialContext)
  82  }
  83  
  84  func defaultFallbackClientHandshakeFuncInternal(fallbackAddr string, dialContextFunc func(context.Context, string, string) (net.Conn, error)) (ClientHandshake, error) {
  85  	fallbackServerAddr, err := processFallbackAddr(fallbackAddr)
  86  	if err != nil {
  87  		if grpclog.V(1) {
  88  			grpclog.Infof("error processing fallback address [%s]: %v", fallbackAddr, err)
  89  		}
  90  		return nil, err
  91  	}
  92  	return func(ctx context.Context, targetServer string, conn net.Conn, s2aErr error) (net.Conn, credentials.AuthInfo, error) {
  93  		fbConn, fbErr := dialContextFunc(ctx, "tcp", fallbackServerAddr)
  94  		if fbErr != nil {
  95  			grpclog.Infof("dialing to fallback server %s failed: %v", fallbackServerAddr, fbErr)
  96  			return nil, nil, fmt.Errorf("dialing to fallback server %s failed: %v; S2A client handshake with %s error: %w", fallbackServerAddr, fbErr, targetServer, s2aErr)
  97  		}
  98  
  99  		tc, success := fbConn.(*tls.Conn)
 100  		if !success {
 101  			grpclog.Infof("the connection with fallback server is expected to be tls but isn't")
 102  			return nil, nil, fmt.Errorf("the connection with fallback server is expected to be tls but isn't; S2A client handshake with %s error: %w", targetServer, s2aErr)
 103  		}
 104  
 105  		tlsInfo := credentials.TLSInfo{
 106  			State: tc.ConnectionState(),
 107  			CommonAuthInfo: credentials.CommonAuthInfo{
 108  				SecurityLevel: credentials.PrivacyAndIntegrity,
 109  			},
 110  		}
 111  		if grpclog.V(1) {
 112  			grpclog.Infof("ConnectionState.NegotiatedProtocol: %v", tc.ConnectionState().NegotiatedProtocol)
 113  			grpclog.Infof("ConnectionState.HandshakeComplete: %v", tc.ConnectionState().HandshakeComplete)
 114  			grpclog.Infof("ConnectionState.ServerName: %v", tc.ConnectionState().ServerName)
 115  		}
 116  		conn.Close()
 117  		return fbConn, tlsInfo, nil
 118  	}, nil
 119  }
 120  
 121  // DefaultFallbackDialerAndAddress returns a TLS dialer and the network address to dial.
 122  // Example use:
 123  //
 124  //	    fallbackDialer, fallbackServerAddr := fallback.DefaultFallbackDialerAndAddress(fallbackAddr)
 125  //		dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
 126  //			S2AAddress:         s2aAddress, // required
 127  //			FallbackOpts: &s2a.FallbackOptions{
 128  //				FallbackDialer: &s2a.FallbackDialer{
 129  //					Dialer:     fallbackDialer,
 130  //					ServerAddr: fallbackServerAddr,
 131  //				},
 132  //			},
 133  //	})
 134  //
 135  // The fallback server's certificate should be verifiable using OS root store.
 136  // The fallbackAddr is expected to be a network address, e.g. example.com:port. If port is not specified,
 137  // it uses default port 443.
 138  // In the returned function's TLS config, ClientSessionCache is explicitly set to nil to disable TLS resumption,
 139  // and min TLS version is set to 1.3.
 140  func DefaultFallbackDialerAndAddress(fallbackAddr string) (*tls.Dialer, string, error) {
 141  	fallbackServerAddr, err := processFallbackAddr(fallbackAddr)
 142  	if err != nil {
 143  		if grpclog.V(1) {
 144  			grpclog.Infof("error processing fallback address [%s]: %v", fallbackAddr, err)
 145  		}
 146  		return nil, "", err
 147  	}
 148  	return &tls.Dialer{Config: &FallbackTLSConfigHTTP}, fallbackServerAddr, nil
 149  }
 150  
 151  func processFallbackAddr(fallbackAddr string) (string, error) {
 152  	var fallbackServerAddr string
 153  	var err error
 154  
 155  	if fallbackAddr == "" {
 156  		return "", fmt.Errorf("empty fallback address")
 157  	}
 158  	_, _, err = net.SplitHostPort(fallbackAddr)
 159  	if err != nil {
 160  		// fallbackAddr does not have port suffix
 161  		fallbackServerAddr = net.JoinHostPort(fallbackAddr, defaultHTTPSPort)
 162  	} else {
 163  		// FallbackServerAddr already has port suffix
 164  		fallbackServerAddr = fallbackAddr
 165  	}
 166  	return fallbackServerAddr, nil
 167  }
 168