s2a.go raw
1 /*
2 *
3 * Copyright 2021 Google LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 // Package s2a provides the S2A transport credentials used by a gRPC
20 // application.
21 package s2a
22
23 import (
24 "context"
25 "crypto/tls"
26 "errors"
27 "fmt"
28 "net"
29 "sync"
30 "time"
31
32 "github.com/google/s2a-go/fallback"
33 "github.com/google/s2a-go/internal/handshaker"
34 "github.com/google/s2a-go/internal/handshaker/service"
35 "github.com/google/s2a-go/internal/tokenmanager"
36 "github.com/google/s2a-go/internal/v2"
37 "github.com/google/s2a-go/retry"
38 "github.com/google/s2a-go/stream"
39 "google.golang.org/grpc/credentials"
40 "google.golang.org/grpc/grpclog"
41 "google.golang.org/protobuf/proto"
42
43 commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
44 commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
45 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
46 )
47
48 const (
49 s2aSecurityProtocol = "tls"
50 // defaultTimeout specifies the default server handshake timeout.
51 defaultTimeout = 30.0 * time.Second
52 )
53
54 // s2aTransportCreds are the transport credentials required for establishing
55 // a secure connection using the S2A. They implement the
56 // credentials.TransportCredentials interface.
57 type s2aTransportCreds struct {
58 info *credentials.ProtocolInfo
59 minTLSVersion commonpbv1.TLSVersion
60 maxTLSVersion commonpbv1.TLSVersion
61 // tlsCiphersuites contains the ciphersuites used in the S2A connection.
62 // Note that these are currently unconfigurable.
63 tlsCiphersuites []commonpbv1.Ciphersuite
64 // localIdentity should only be used by the client.
65 localIdentity *commonpbv1.Identity
66 // localIdentities should only be used by the server.
67 localIdentities []*commonpbv1.Identity
68 // targetIdentities should only be used by the client.
69 targetIdentities []*commonpbv1.Identity
70 isClient bool
71 s2aAddr string
72 ensureProcessSessionTickets *sync.WaitGroup
73 }
74
75 // NewClientCreds returns a client-side transport credentials object that uses
76 // the S2A to establish a secure connection with a server.
77 func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
78 if opts == nil {
79 return nil, errors.New("nil client options")
80 }
81 var targetIdentities []*commonpbv1.Identity
82 for _, targetIdentity := range opts.TargetIdentities {
83 protoTargetIdentity, err := toProtoIdentity(targetIdentity)
84 if err != nil {
85 return nil, err
86 }
87 targetIdentities = append(targetIdentities, protoTargetIdentity)
88 }
89 localIdentity, err := toProtoIdentity(opts.LocalIdentity)
90 if err != nil {
91 return nil, err
92 }
93 if opts.EnableLegacyMode {
94 return &s2aTransportCreds{
95 info: &credentials.ProtocolInfo{
96 SecurityProtocol: s2aSecurityProtocol,
97 },
98 minTLSVersion: commonpbv1.TLSVersion_TLS1_3,
99 maxTLSVersion: commonpbv1.TLSVersion_TLS1_3,
100 tlsCiphersuites: []commonpbv1.Ciphersuite{
101 commonpbv1.Ciphersuite_AES_128_GCM_SHA256,
102 commonpbv1.Ciphersuite_AES_256_GCM_SHA384,
103 commonpbv1.Ciphersuite_CHACHA20_POLY1305_SHA256,
104 },
105 localIdentity: localIdentity,
106 targetIdentities: targetIdentities,
107 isClient: true,
108 s2aAddr: opts.S2AAddress,
109 ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
110 }, nil
111 }
112 verificationMode := getVerificationMode(opts.VerificationMode)
113 var fallbackFunc fallback.ClientHandshake
114 if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
115 fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
116 }
117 v2LocalIdentity, err := toV2ProtoIdentity(opts.LocalIdentity)
118 if err != nil {
119 return nil, err
120 }
121 return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, v2LocalIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
122 }
123
124 // NewServerCreds returns a server-side transport credentials object that uses
125 // the S2A to establish a secure connection with a client.
126 func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
127 if opts == nil {
128 return nil, errors.New("nil server options")
129 }
130 var localIdentities []*commonpbv1.Identity
131 for _, localIdentity := range opts.LocalIdentities {
132 protoLocalIdentity, err := toProtoIdentity(localIdentity)
133 if err != nil {
134 return nil, err
135 }
136 localIdentities = append(localIdentities, protoLocalIdentity)
137 }
138 if opts.EnableLegacyMode {
139 return &s2aTransportCreds{
140 info: &credentials.ProtocolInfo{
141 SecurityProtocol: s2aSecurityProtocol,
142 },
143 minTLSVersion: commonpbv1.TLSVersion_TLS1_3,
144 maxTLSVersion: commonpbv1.TLSVersion_TLS1_3,
145 tlsCiphersuites: []commonpbv1.Ciphersuite{
146 commonpbv1.Ciphersuite_AES_128_GCM_SHA256,
147 commonpbv1.Ciphersuite_AES_256_GCM_SHA384,
148 commonpbv1.Ciphersuite_CHACHA20_POLY1305_SHA256,
149 },
150 localIdentities: localIdentities,
151 isClient: false,
152 s2aAddr: opts.S2AAddress,
153 }, nil
154 }
155 verificationMode := getVerificationMode(opts.VerificationMode)
156 var v2LocalIdentities []*commonpb.Identity
157 for _, localIdentity := range opts.LocalIdentities {
158 protoLocalIdentity, err := toV2ProtoIdentity(localIdentity)
159 if err != nil {
160 return nil, err
161 }
162 v2LocalIdentities = append(v2LocalIdentities, protoLocalIdentity)
163 }
164 return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, v2LocalIdentities, verificationMode, opts.getS2AStream)
165 }
166
167 // ClientHandshake initiates a client-side TLS handshake using the S2A.
168 func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
169 if !c.isClient {
170 return nil, nil, errors.New("client handshake called using server transport credentials")
171 }
172
173 var cancel context.CancelFunc
174 ctx, cancel = context.WithCancel(ctx)
175 defer cancel()
176
177 // Connect to the S2A.
178 hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
179 if err != nil {
180 grpclog.Infof("Failed to connect to S2A: %v", err)
181 return nil, nil, err
182 }
183
184 opts := &handshaker.ClientHandshakerOptions{
185 MinTLSVersion: c.minTLSVersion,
186 MaxTLSVersion: c.maxTLSVersion,
187 TLSCiphersuites: c.tlsCiphersuites,
188 TargetIdentities: c.targetIdentities,
189 LocalIdentity: c.localIdentity,
190 TargetName: serverAuthority,
191 EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
192 }
193 chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
194 if err != nil {
195 grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
196 return nil, nil, err
197 }
198 defer func() {
199 if err != nil {
200 if closeErr := chs.Close(); closeErr != nil {
201 grpclog.Infof("Close failed unexpectedly: %v", err)
202 err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
203 }
204 }
205 }()
206
207 secConn, authInfo, err := chs.ClientHandshake(context.Background())
208 if err != nil {
209 grpclog.Infof("Handshake failed: %v", err)
210 return nil, nil, err
211 }
212 return secConn, authInfo, nil
213 }
214
215 // ServerHandshake initiates a server-side TLS handshake using the S2A.
216 func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
217 if c.isClient {
218 return nil, nil, errors.New("server handshake called using client transport credentials")
219 }
220
221 ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
222 defer cancel()
223
224 // Connect to the S2A.
225 hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
226 if err != nil {
227 grpclog.Infof("Failed to connect to S2A: %v", err)
228 return nil, nil, err
229 }
230
231 opts := &handshaker.ServerHandshakerOptions{
232 MinTLSVersion: c.minTLSVersion,
233 MaxTLSVersion: c.maxTLSVersion,
234 TLSCiphersuites: c.tlsCiphersuites,
235 LocalIdentities: c.localIdentities,
236 }
237 shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
238 if err != nil {
239 grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
240 return nil, nil, err
241 }
242 defer func() {
243 if err != nil {
244 if closeErr := shs.Close(); closeErr != nil {
245 grpclog.Infof("Close failed unexpectedly: %v", err)
246 err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
247 }
248 }
249 }()
250
251 secConn, authInfo, err := shs.ServerHandshake(context.Background())
252 if err != nil {
253 grpclog.Infof("Handshake failed: %v", err)
254 return nil, nil, err
255 }
256 return secConn, authInfo, nil
257 }
258
259 func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
260 return *c.info
261 }
262
263 func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
264 info := *c.info
265 var localIdentity *commonpbv1.Identity
266 if c.localIdentity != nil {
267 localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
268 }
269 var localIdentities []*commonpbv1.Identity
270 if c.localIdentities != nil {
271 localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
272 for i, localIdentity := range c.localIdentities {
273 localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
274 }
275 }
276 var targetIdentities []*commonpbv1.Identity
277 if c.targetIdentities != nil {
278 targetIdentities = make([]*commonpbv1.Identity, len(c.targetIdentities))
279 for i, targetIdentity := range c.targetIdentities {
280 targetIdentities[i] = proto.Clone(targetIdentity).(*commonpbv1.Identity)
281 }
282 }
283 return &s2aTransportCreds{
284 info: &info,
285 minTLSVersion: c.minTLSVersion,
286 maxTLSVersion: c.maxTLSVersion,
287 tlsCiphersuites: c.tlsCiphersuites,
288 localIdentity: localIdentity,
289 localIdentities: localIdentities,
290 targetIdentities: targetIdentities,
291 isClient: c.isClient,
292 s2aAddr: c.s2aAddr,
293 }
294 }
295
296 func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
297 c.info.ServerName = serverNameOverride
298 return nil
299 }
300
301 // TLSClientConfigOptions specifies parameters for creating client TLS config.
302 type TLSClientConfigOptions struct {
303 // ServerName is required by s2a as the expected name when verifying the hostname found in server's certificate.
304 // tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
305 // ServerName: "example.com",
306 // })
307 ServerName string
308 }
309
310 // TLSClientConfigFactory defines the interface for a client TLS config factory.
311 type TLSClientConfigFactory interface {
312 Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
313 }
314
315 // NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
316 func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
317 if opts == nil {
318 return nil, fmt.Errorf("opts must be non-nil")
319 }
320 if opts.EnableLegacyMode {
321 return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
322 }
323 tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
324 if err != nil {
325 // The only possible error is: access token not set in the environment,
326 // which is okay in environments other than serverless.
327 grpclog.Infof("Access token manager not initialized: %v", err)
328 return &s2aTLSClientConfigFactory{
329 s2av2Address: opts.S2AAddress,
330 transportCreds: opts.TransportCreds,
331 tokenManager: nil,
332 verificationMode: getVerificationMode(opts.VerificationMode),
333 serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
334 getStream: opts.getS2AStream,
335 }, nil
336 }
337 return &s2aTLSClientConfigFactory{
338 s2av2Address: opts.S2AAddress,
339 transportCreds: opts.TransportCreds,
340 tokenManager: tokenManager,
341 verificationMode: getVerificationMode(opts.VerificationMode),
342 serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
343 getStream: opts.getS2AStream,
344 }, nil
345 }
346
347 type s2aTLSClientConfigFactory struct {
348 s2av2Address string
349 transportCreds credentials.TransportCredentials
350 tokenManager tokenmanager.AccessTokenManager
351 verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
352 serverAuthorizationPolicy []byte
353 getStream stream.GetS2AStream
354 }
355
356 func (f *s2aTLSClientConfigFactory) Build(
357 ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
358 serverName := ""
359 if opts != nil && opts.ServerName != "" {
360 serverName = opts.ServerName
361 }
362 return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy, f.getStream)
363 }
364
365 func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
366 switch verificationMode {
367 case ConnectToGoogle:
368 return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
369 case Spiffe:
370 return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
371 case ReservedCustomVerificationMode3:
372 return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_3
373 case ReservedCustomVerificationMode4:
374 return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_4
375 case ReservedCustomVerificationMode5:
376 return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_5
377 case ReservedCustomVerificationMode6:
378 return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_6
379 default:
380 return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
381 }
382 }
383
384 // NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
385 // Example use with http.RoundTripper:
386 //
387 // dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
388 // S2AAddress: s2aAddress, // required
389 // })
390 // transport := http.DefaultTransport
391 // transport.DialTLSContext = dialTLSContext
392 func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
393
394 return func(ctx context.Context, network, addr string) (net.Conn, error) {
395
396 fallback := func(err error) (net.Conn, error) {
397 if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
398 opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
399 fbDialer := opts.FallbackOpts.FallbackDialer
400 grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
401 fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
402 if fbErr != nil {
403 return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
404 }
405 return fbConn, nil
406 }
407 return nil, err
408 }
409
410 factory, err := NewTLSClientConfigFactory(opts)
411 if err != nil {
412 grpclog.Infof("error creating S2A client config factory: %v", err)
413 return fallback(err)
414 }
415
416 serverName, _, err := net.SplitHostPort(addr)
417 if err != nil {
418 serverName = addr
419 }
420 timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
421 defer cancel()
422
423 var s2aTLSConfig *tls.Config
424 var c net.Conn
425 retry.Run(timeoutCtx,
426 func() error {
427 s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{
428 ServerName: serverName,
429 })
430 if err != nil {
431 grpclog.Infof("error building S2A TLS config: %v", err)
432 return err
433 }
434
435 s2aDialer := &tls.Dialer{
436 Config: s2aTLSConfig,
437 }
438 c, err = s2aDialer.DialContext(timeoutCtx, network, addr)
439 return err
440 })
441 if err != nil {
442 grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
443 return fallback(err)
444 }
445 grpclog.Infof("success dialing MTLS to %s with S2A", addr)
446 return c, nil
447 }
448 }
449