ticketsender.go raw

   1  /*
   2   *
   3   * Copyright 2021 Google LLC
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     https://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  package record
  20  
  21  import (
  22  	"context"
  23  	"fmt"
  24  	"sync"
  25  	"time"
  26  
  27  	"github.com/google/s2a-go/internal/handshaker/service"
  28  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
  29  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
  30  	"github.com/google/s2a-go/internal/tokenmanager"
  31  	"google.golang.org/grpc/codes"
  32  	"google.golang.org/grpc/grpclog"
  33  )
  34  
  35  // sessionTimeout is the timeout for creating a session with the S2A handshaker
  36  // service.
  37  const sessionTimeout = time.Second * 5
  38  
  39  // s2aTicketSender sends session tickets to the S2A handshaker service.
  40  type s2aTicketSender interface {
  41  	// sendTicketsToS2A sends the given session tickets to the S2A handshaker
  42  	// service.
  43  	sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool)
  44  }
  45  
  46  // ticketStream is the stream used to send and receive session information.
  47  type ticketStream interface {
  48  	Send(*s2apb.SessionReq) error
  49  	Recv() (*s2apb.SessionResp, error)
  50  }
  51  
// ticketSender implements s2aTicketSender by forwarding session tickets to the
// S2A handshaker service reachable at hsAddr.
type ticketSender struct {
	// hsAddr stores the address of the S2A handshaker service.
	hsAddr string
	// connectionID is the connection identifier that was created and sent by
	// S2A at the end of a handshake.
	connectionID uint64
	// localIdentity is the local identity that was used by S2A during session
	// setup and included in the session result. It may be nil, in which case
	// the default token from tokenManager is used for authentication.
	localIdentity *commonpb.Identity
	// tokenManager manages access tokens for authenticating to S2A. It may be
	// nil, in which case requests carry no authentication mechanisms.
	tokenManager tokenmanager.AccessTokenManager
	// ensureProcessSessionTickets allows users to wait and ensure that all
	// available session tickets are sent to S2A before a process completes.
	// It may be nil, in which case no waiting is tracked.
	ensureProcessSessionTickets *sync.WaitGroup
}
  67  
  68  // sendTicketsToS2A sends the given sessionTickets to the S2A handshaker
  69  // service. This is done asynchronously and writes to the error logs if an error
  70  // occurs.
  71  func (t *ticketSender) sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool) {
  72  	// Note that the goroutine is in the function rather than at the caller
  73  	// because the fake ticket sender used for testing must run synchronously
  74  	// so that the session tickets can be accessed from it after the tests have
  75  	// been run.
  76  	if t.ensureProcessSessionTickets != nil {
  77  		t.ensureProcessSessionTickets.Add(1)
  78  	}
  79  	go func() {
  80  		if err := func() error {
  81  			defer func() {
  82  				if t.ensureProcessSessionTickets != nil {
  83  					t.ensureProcessSessionTickets.Done()
  84  				}
  85  			}()
  86  			ctx, cancel := context.WithTimeout(context.Background(), sessionTimeout)
  87  			defer cancel()
  88  			// The transportCreds only needs to be set when talking to S2AV2 and also
  89  			// if mTLS is required.
  90  			hsConn, err := service.Dial(ctx, t.hsAddr, nil)
  91  			if err != nil {
  92  				return err
  93  			}
  94  			client := s2apb.NewS2AServiceClient(hsConn)
  95  			session, err := client.SetUpSession(ctx)
  96  			if err != nil {
  97  				return err
  98  			}
  99  			defer func() {
 100  				if err := session.CloseSend(); err != nil {
 101  					grpclog.Error(err)
 102  				}
 103  			}()
 104  			return t.writeTicketsToStream(session, sessionTickets)
 105  		}(); err != nil {
 106  			grpclog.Errorf("failed to send resumption tickets to S2A with identity: %v, %v",
 107  				t.localIdentity, err)
 108  		}
 109  		callComplete <- true
 110  		close(callComplete)
 111  	}()
 112  }
 113  
 114  // writeTicketsToStream writes the given session tickets to the given stream.
 115  func (t *ticketSender) writeTicketsToStream(stream ticketStream, sessionTickets [][]byte) error {
 116  	if err := stream.Send(
 117  		&s2apb.SessionReq{
 118  			ReqOneof: &s2apb.SessionReq_ResumptionTicket{
 119  				ResumptionTicket: &s2apb.ResumptionTicketReq{
 120  					InBytes:       sessionTickets,
 121  					ConnectionId:  t.connectionID,
 122  					LocalIdentity: t.localIdentity,
 123  				},
 124  			},
 125  			AuthMechanisms: t.getAuthMechanisms(),
 126  		},
 127  	); err != nil {
 128  		return err
 129  	}
 130  	sessionResp, err := stream.Recv()
 131  	if err != nil {
 132  		return err
 133  	}
 134  	if sessionResp.GetStatus().GetCode() != uint32(codes.OK) {
 135  		return fmt.Errorf("s2a session ticket response had error status: %v, %v",
 136  			sessionResp.GetStatus().GetCode(), sessionResp.GetStatus().GetDetails())
 137  	}
 138  	return nil
 139  }
 140  
 141  func (t *ticketSender) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
 142  	if t.tokenManager == nil {
 143  		return nil
 144  	}
 145  	// First handle the special case when no local identity has been provided
 146  	// by the application. In this case, an AuthenticationMechanism with no local
 147  	// identity will be sent.
 148  	if t.localIdentity == nil {
 149  		token, err := t.tokenManager.DefaultToken()
 150  		if err != nil {
 151  			grpclog.Infof("unable to get token for empty local identity: %v", err)
 152  			return nil
 153  		}
 154  		return []*s2apb.AuthenticationMechanism{
 155  			{
 156  				MechanismOneof: &s2apb.AuthenticationMechanism_Token{
 157  					Token: token,
 158  				},
 159  			},
 160  		}
 161  	}
 162  
 163  	// Next, handle the case where the application (or the S2A) has specified
 164  	// a local identity.
 165  	token, err := t.tokenManager.Token(t.localIdentity)
 166  	if err != nil {
 167  		grpclog.Infof("unable to get token for local identity %v: %v", t.localIdentity, err)
 168  		return nil
 169  	}
 170  	return []*s2apb.AuthenticationMechanism{
 171  		{
 172  			Identity: t.localIdentity,
 173  			MechanismOneof: &s2apb.AuthenticationMechanism_Token{
 174  				Token: token,
 175  			},
 176  		},
 177  	}
 178  }
 179