handshaker.go raw

   1  /*
   2   *
   3   * Copyright 2021 Google LLC
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     https://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  // Package handshaker communicates with the S2A handshaker service.
  20  package handshaker
  21  
  22  import (
  23  	"context"
  24  	"errors"
  25  	"fmt"
  26  	"io"
  27  	"net"
  28  	"sync"
  29  
  30  	"github.com/google/s2a-go/internal/authinfo"
  31  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
  32  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
  33  	"github.com/google/s2a-go/internal/record"
  34  	"github.com/google/s2a-go/internal/tokenmanager"
  35  	grpc "google.golang.org/grpc"
  36  	"google.golang.org/grpc/codes"
  37  	"google.golang.org/grpc/credentials"
  38  	"google.golang.org/grpc/grpclog"
  39  )
  40  
  41  var (
  42  	// appProtocol contains the application protocol accepted by the handshaker.
  43  	appProtocol = "grpc"
  44  	// frameLimit is the maximum size of a frame in bytes.
  45  	frameLimit = 1024 * 64
  46  	// peerNotRespondingError is the error thrown when the peer doesn't respond.
  47  	errPeerNotResponding = errors.New("peer is not responding and re-connection should be attempted")
  48  )
  49  
  50  // Handshaker defines a handshaker interface.
  51  type Handshaker interface {
  52  	// ClientHandshake starts and completes a TLS handshake from the client side,
  53  	// and returns a secure connection along with additional auth information.
  54  	ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
  55  	// ServerHandshake starts and completes a TLS handshake from the server side,
  56  	// and returns a secure connection along with additional auth information.
  57  	ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
  58  	// Close terminates the Handshaker. It should be called when the handshake
  59  	// is complete.
  60  	Close() error
  61  }
  62  
  63  // ClientHandshakerOptions contains the options needed to configure the S2A
  64  // handshaker service on the client-side.
  65  type ClientHandshakerOptions struct {
  66  	// MinTLSVersion specifies the min TLS version supported by the client.
  67  	MinTLSVersion commonpb.TLSVersion
  68  	// MaxTLSVersion specifies the max TLS version supported by the client.
  69  	MaxTLSVersion commonpb.TLSVersion
  70  	// TLSCiphersuites is the ordered list of ciphersuites supported by the
  71  	// client.
  72  	TLSCiphersuites []commonpb.Ciphersuite
  73  	// TargetIdentities contains a list of allowed server identities. One of the
  74  	// target identities should match the peer identity in the handshake
  75  	// result; otherwise, the handshake fails.
  76  	TargetIdentities []*commonpb.Identity
  77  	// LocalIdentity is the local identity of the client application. If none is
  78  	// provided, then the S2A will choose the default identity.
  79  	LocalIdentity *commonpb.Identity
  80  	// TargetName is the allowed server name, which may be used for server
  81  	// authorization check by the S2A if it is provided.
  82  	TargetName string
  83  	// EnsureProcessSessionTickets allows users to wait and ensure that all
  84  	// available session tickets are sent to S2A before a process completes.
  85  	EnsureProcessSessionTickets *sync.WaitGroup
  86  }
  87  
  88  // ServerHandshakerOptions contains the options needed to configure the S2A
  89  // handshaker service on the server-side.
  90  type ServerHandshakerOptions struct {
  91  	// MinTLSVersion specifies the min TLS version supported by the server.
  92  	MinTLSVersion commonpb.TLSVersion
  93  	// MaxTLSVersion specifies the max TLS version supported by the server.
  94  	MaxTLSVersion commonpb.TLSVersion
  95  	// TLSCiphersuites is the ordered list of ciphersuites supported by the
  96  	// server.
  97  	TLSCiphersuites []commonpb.Ciphersuite
  98  	// LocalIdentities is the list of local identities that may be assumed by
  99  	// the server. If no local identity is specified, then the S2A chooses a
 100  	// default local identity.
 101  	LocalIdentities []*commonpb.Identity
 102  }
 103  
// s2aHandshaker performs a TLS handshake using the S2A handshaker service.
// It relays handshake bytes between the peer connection and the S2A stream
// and, on success, wraps the peer connection in a record-layer connection
// built from the session result.
type s2aHandshaker struct {
	// stream is used to communicate with the S2A handshaker service.
	stream s2apb.S2AService_SetUpSessionClient
	// conn is the connection to the peer. Handshake frames are written to and
	// read from it directly.
	conn net.Conn
	// clientOpts holds the client-side options. It is only set by
	// newClientHandshaker, which also accepts a nil value, so it may be nil
	// even for a client-side handshaker.
	clientOpts *ClientHandshakerOptions
	// serverOpts holds the server-side options. It is only set by
	// newServerHandshaker, which also accepts a nil value.
	serverOpts *ServerHandshakerOptions
	// isClient determines if the handshaker is client or server side.
	isClient bool
	// hsAddr stores the address of the S2A handshaker service. It is handed
	// to the record layer when the secure connection is created.
	hsAddr string
	// tokenManager manages access tokens for authenticating to S2A. When it
	// is nil, no authentication mechanisms are attached to requests.
	tokenManager tokenmanager.AccessTokenManager
	// localIdentities is the set of local identities for whom the
	// tokenManager should fetch a token when preparing a request to be
	// sent to S2A. Once the S2A reports a local identity in a response, this
	// is replaced by that single identity.
	localIdentities []*commonpb.Identity
}
 125  
 126  // NewClientHandshaker creates an s2aHandshaker instance that performs a
 127  // client-side TLS handshake using the S2A handshaker service.
 128  func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ClientHandshakerOptions) (Handshaker, error) {
 129  	stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
 130  	if err != nil {
 131  		return nil, err
 132  	}
 133  	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
 134  	if err != nil {
 135  		grpclog.Infof("failed to create single token access token manager: %v", err)
 136  	}
 137  	return newClientHandshaker(stream, c, hsAddr, opts, tokenManager), nil
 138  }
 139  
 140  func newClientHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ClientHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
 141  	var localIdentities []*commonpb.Identity
 142  	if opts != nil {
 143  		localIdentities = []*commonpb.Identity{opts.LocalIdentity}
 144  	}
 145  	return &s2aHandshaker{
 146  		stream:          stream,
 147  		conn:            c,
 148  		clientOpts:      opts,
 149  		isClient:        true,
 150  		hsAddr:          hsAddr,
 151  		tokenManager:    tokenManager,
 152  		localIdentities: localIdentities,
 153  	}
 154  }
 155  
 156  // NewServerHandshaker creates an s2aHandshaker instance that performs a
 157  // server-side TLS handshake using the S2A handshaker service.
 158  func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ServerHandshakerOptions) (Handshaker, error) {
 159  	stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
 160  	if err != nil {
 161  		return nil, err
 162  	}
 163  	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
 164  	if err != nil {
 165  		grpclog.Infof("failed to create single token access token manager: %v", err)
 166  	}
 167  	return newServerHandshaker(stream, c, hsAddr, opts, tokenManager), nil
 168  }
 169  
 170  func newServerHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ServerHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
 171  	var localIdentities []*commonpb.Identity
 172  	if opts != nil {
 173  		localIdentities = opts.LocalIdentities
 174  	}
 175  	return &s2aHandshaker{
 176  		stream:          stream,
 177  		conn:            c,
 178  		serverOpts:      opts,
 179  		isClient:        false,
 180  		hsAddr:          hsAddr,
 181  		tokenManager:    tokenManager,
 182  		localIdentities: localIdentities,
 183  	}
 184  }
 185  
 186  // ClientHandshake performs a client-side TLS handshake using the S2A handshaker
 187  // service. When complete, returns a TLS connection.
 188  func (h *s2aHandshaker) ClientHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
 189  	if !h.isClient {
 190  		return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client-side handshake")
 191  	}
 192  	// Extract the hostname from the target name. The target name is assumed to be an authority.
 193  	hostname, _, err := net.SplitHostPort(h.clientOpts.TargetName)
 194  	if err != nil {
 195  		// If the target name had no host port or could not be parsed, use it as is.
 196  		hostname = h.clientOpts.TargetName
 197  	}
 198  
 199  	// Prepare a client start message to send to the S2A handshaker service.
 200  	req := &s2apb.SessionReq{
 201  		ReqOneof: &s2apb.SessionReq_ClientStart{
 202  			ClientStart: &s2apb.ClientSessionStartReq{
 203  				ApplicationProtocols: []string{appProtocol},
 204  				MinTlsVersion:        h.clientOpts.MinTLSVersion,
 205  				MaxTlsVersion:        h.clientOpts.MaxTLSVersion,
 206  				TlsCiphersuites:      h.clientOpts.TLSCiphersuites,
 207  				TargetIdentities:     h.clientOpts.TargetIdentities,
 208  				LocalIdentity:        h.clientOpts.LocalIdentity,
 209  				TargetName:           hostname,
 210  			},
 211  		},
 212  		AuthMechanisms: h.getAuthMechanisms(),
 213  	}
 214  	conn, result, err := h.setUpSession(req)
 215  	if err != nil {
 216  		return nil, nil, err
 217  	}
 218  	authInfo, err := authinfo.NewS2AAuthInfo(result)
 219  	if err != nil {
 220  		return nil, nil, err
 221  	}
 222  	return conn, authInfo, nil
 223  }
 224  
 225  // ServerHandshake performs a server-side TLS handshake using the S2A handshaker
 226  // service. When complete, returns a TLS connection.
 227  func (h *s2aHandshaker) ServerHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
 228  	if h.isClient {
 229  		return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server-side handshake")
 230  	}
 231  	p := make([]byte, frameLimit)
 232  	n, err := h.conn.Read(p)
 233  	if err != nil {
 234  		return nil, nil, err
 235  	}
 236  	// Prepare a server start message to send to the S2A handshaker service.
 237  	req := &s2apb.SessionReq{
 238  		ReqOneof: &s2apb.SessionReq_ServerStart{
 239  			ServerStart: &s2apb.ServerSessionStartReq{
 240  				ApplicationProtocols: []string{appProtocol},
 241  				MinTlsVersion:        h.serverOpts.MinTLSVersion,
 242  				MaxTlsVersion:        h.serverOpts.MaxTLSVersion,
 243  				TlsCiphersuites:      h.serverOpts.TLSCiphersuites,
 244  				LocalIdentities:      h.serverOpts.LocalIdentities,
 245  				InBytes:              p[:n],
 246  			},
 247  		},
 248  		AuthMechanisms: h.getAuthMechanisms(),
 249  	}
 250  	conn, result, err := h.setUpSession(req)
 251  	if err != nil {
 252  		return nil, nil, err
 253  	}
 254  	authInfo, err := authinfo.NewS2AAuthInfo(result)
 255  	if err != nil {
 256  		return nil, nil, err
 257  	}
 258  	return conn, authInfo, nil
 259  }
 260  
 261  // setUpSession proxies messages between the peer and the S2A handshaker
 262  // service.
 263  func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.SessionResult, error) {
 264  	resp, err := h.accessHandshakerService(req)
 265  	if err != nil {
 266  		return nil, nil, err
 267  	}
 268  	// Check if the returned status is an error.
 269  	if resp.GetStatus() != nil {
 270  		if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
 271  			return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
 272  		}
 273  	}
 274  	// Calculate the extra unread bytes from the Session. Attempting to consume
 275  	// more than the bytes sent will throw an error.
 276  	var extra []byte
 277  	if req.GetServerStart() != nil {
 278  		if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
 279  			return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
 280  		}
 281  		extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
 282  	}
 283  	result, extra, err := h.processUntilDone(resp, extra)
 284  	if err != nil {
 285  		return nil, nil, err
 286  	}
 287  	if result.GetLocalIdentity() == nil {
 288  		return nil, nil, errors.New("local identity must be populated in session result")
 289  	}
 290  
 291  	// Create a new TLS record protocol using the Session Result.
 292  	newConn, err := record.NewConn(&record.ConnParameters{
 293  		NetConn:                     h.conn,
 294  		Ciphersuite:                 result.GetState().GetTlsCiphersuite(),
 295  		TLSVersion:                  result.GetState().GetTlsVersion(),
 296  		InTrafficSecret:             result.GetState().GetInKey(),
 297  		OutTrafficSecret:            result.GetState().GetOutKey(),
 298  		UnusedBuf:                   extra,
 299  		InSequence:                  result.GetState().GetInSequence(),
 300  		OutSequence:                 result.GetState().GetOutSequence(),
 301  		HSAddr:                      h.hsAddr,
 302  		ConnectionID:                result.GetState().GetConnectionId(),
 303  		LocalIdentity:               result.GetLocalIdentity(),
 304  		EnsureProcessSessionTickets: h.ensureProcessSessionTickets(),
 305  	})
 306  	if err != nil {
 307  		return nil, nil, err
 308  	}
 309  	return newConn, result, nil
 310  }
 311  
 312  func (h *s2aHandshaker) ensureProcessSessionTickets() *sync.WaitGroup {
 313  	if h.clientOpts == nil {
 314  		return nil
 315  	}
 316  	return h.clientOpts.EnsureProcessSessionTickets
 317  }
 318  
 319  // accessHandshakerService sends the session request to the S2A handshaker
 320  // service and returns the session response.
 321  func (h *s2aHandshaker) accessHandshakerService(req *s2apb.SessionReq) (*s2apb.SessionResp, error) {
 322  	if err := h.stream.Send(req); err != nil {
 323  		return nil, err
 324  	}
 325  	resp, err := h.stream.Recv()
 326  	if err != nil {
 327  		return nil, err
 328  	}
 329  	return resp, nil
 330  }
 331  
 332  // processUntilDone continues proxying messages between the peer and the S2A
 333  // handshaker service until the handshaker service returns the SessionResult at
 334  // the end of the handshake or an error occurs.
 335  func (h *s2aHandshaker) processUntilDone(resp *s2apb.SessionResp, unusedBytes []byte) (*s2apb.SessionResult, []byte, error) {
 336  	for {
 337  		if len(resp.OutFrames) > 0 {
 338  			if _, err := h.conn.Write(resp.OutFrames); err != nil {
 339  				return nil, nil, err
 340  			}
 341  		}
 342  		if resp.Result != nil {
 343  			return resp.Result, unusedBytes, nil
 344  		}
 345  		buf := make([]byte, frameLimit)
 346  		n, err := h.conn.Read(buf)
 347  		if err != nil && err != io.EOF {
 348  			return nil, nil, err
 349  		}
 350  		// If there is nothing to send to the handshaker service and nothing is
 351  		// received from the peer, then we are stuck. This covers the case when
 352  		// the peer is not responding. Note that handshaker service connection
 353  		// issues are caught in accessHandshakerService before we even get
 354  		// here.
 355  		if len(resp.OutFrames) == 0 && n == 0 {
 356  			return nil, nil, errPeerNotResponding
 357  		}
 358  		// Append extra bytes from the previous interaction with the handshaker
 359  		// service with the current buffer read from conn.
 360  		p := append(unusedBytes, buf[:n]...)
 361  		// From here on, p and unusedBytes point to the same slice.
 362  		resp, err = h.accessHandshakerService(&s2apb.SessionReq{
 363  			ReqOneof: &s2apb.SessionReq_Next{
 364  				Next: &s2apb.SessionNextReq{
 365  					InBytes: p,
 366  				},
 367  			},
 368  			AuthMechanisms: h.getAuthMechanisms(),
 369  		})
 370  		if err != nil {
 371  			return nil, nil, err
 372  		}
 373  
 374  		// Cache the local identity returned by S2A, if it is populated. This
 375  		// overwrites any existing local identities. This is done because, once the
 376  		// S2A has selected a local identity, then only that local identity should
 377  		// be asserted in future requests until the end of the current handshake.
 378  		if resp.GetLocalIdentity() != nil {
 379  			h.localIdentities = []*commonpb.Identity{resp.GetLocalIdentity()}
 380  		}
 381  
 382  		// Set unusedBytes based on the handshaker service response.
 383  		if resp.GetBytesConsumed() > uint32(len(p)) {
 384  			return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
 385  		}
 386  		unusedBytes = p[resp.GetBytesConsumed():]
 387  	}
 388  }
 389  
 390  // Close shuts down the handshaker and the stream to the S2A handshaker service
 391  // when the handshake is complete. It should be called when the caller obtains
 392  // the secure connection at the end of the handshake.
 393  func (h *s2aHandshaker) Close() error {
 394  	return h.stream.CloseSend()
 395  }
 396  
 397  func (h *s2aHandshaker) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
 398  	if h.tokenManager == nil {
 399  		return nil
 400  	}
 401  	// First handle the special case when no local identities have been provided
 402  	// by the application. In this case, an AuthenticationMechanism with no local
 403  	// identity will be sent.
 404  	if len(h.localIdentities) == 0 {
 405  		token, err := h.tokenManager.DefaultToken()
 406  		if err != nil {
 407  			grpclog.Infof("unable to get token for empty local identity: %v", err)
 408  			return nil
 409  		}
 410  		return []*s2apb.AuthenticationMechanism{
 411  			{
 412  				MechanismOneof: &s2apb.AuthenticationMechanism_Token{
 413  					Token: token,
 414  				},
 415  			},
 416  		}
 417  	}
 418  
 419  	// Next, handle the case where the application (or the S2A) has provided
 420  	// one or more local identities.
 421  	var authMechanisms []*s2apb.AuthenticationMechanism
 422  	for _, localIdentity := range h.localIdentities {
 423  		token, err := h.tokenManager.Token(localIdentity)
 424  		if err != nil {
 425  			grpclog.Infof("unable to get token for local identity %v: %v", localIdentity, err)
 426  			continue
 427  		}
 428  
 429  		authMechanism := &s2apb.AuthenticationMechanism{
 430  			Identity: localIdentity,
 431  			MechanismOneof: &s2apb.AuthenticationMechanism_Token{
 432  				Token: token,
 433  			},
 434  		}
 435  		authMechanisms = append(authMechanisms, authMechanism)
 436  	}
 437  	return authMechanisms
 438  }
 439