s2av2.go raw
1 /*
2 *
3 * Copyright 2022 Google LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 // Package v2 provides the S2Av2 transport credentials used by a gRPC
20 // application.
21 package v2
22
23 import (
24 "context"
25 "crypto/tls"
26 "errors"
27 "net"
28 "os"
29 "time"
30
31 "github.com/google/s2a-go/fallback"
32 "github.com/google/s2a-go/internal/handshaker/service"
33 "github.com/google/s2a-go/internal/tokenmanager"
34 "github.com/google/s2a-go/internal/v2/tlsconfigstore"
35 "github.com/google/s2a-go/retry"
36 "github.com/google/s2a-go/stream"
37 "google.golang.org/grpc"
38 "google.golang.org/grpc/credentials"
39 "google.golang.org/grpc/grpclog"
40 "google.golang.org/protobuf/proto"
41
42 commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
43 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
44 )
45
// Settings shared by the S2Av2 transport credentials in this package.
const (
	// s2aSecurityProtocol is the security protocol name reported through
	// credentials.ProtocolInfo.
	s2aSecurityProtocol = "tls"

	// defaultS2ATimeout is the handshake timeout used when the S2A_TIMEOUT
	// environment variable does not hold a parsable duration.
	defaultS2ATimeout = 6 * time.Second

	// s2aTimeoutEnv names the environment variable that sets the timeout
	// enforced on the connection to the S2A service for handshake.
	s2aTimeoutEnv = "S2A_TIMEOUT"
)
53
54 type s2av2TransportCreds struct {
55 info *credentials.ProtocolInfo
56 isClient bool
57 serverName string
58 s2av2Address string
59 transportCreds credentials.TransportCredentials
60 tokenManager *tokenmanager.AccessTokenManager
61 // localIdentity should only be used by the client.
62 localIdentity *commonpb.Identity
63 // localIdentities should only be used by the server.
64 localIdentities []*commonpb.Identity
65 verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
66 fallbackClientHandshake fallback.ClientHandshake
67 getS2AStream stream.GetS2AStream
68 serverAuthorizationPolicy []byte
69 }
70
71 // NewClientCreds returns a client-side transport credentials object that uses
72 // the S2Av2 to establish a secure connection with a server.
73 func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream stream.GetS2AStream, serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
74 // Create an AccessTokenManager instance to use to authenticate to S2Av2.
75 accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
76
77 creds := &s2av2TransportCreds{
78 info: &credentials.ProtocolInfo{
79 SecurityProtocol: s2aSecurityProtocol,
80 },
81 isClient: true,
82 serverName: "",
83 s2av2Address: s2av2Address,
84 transportCreds: transportCreds,
85 localIdentity: localIdentity,
86 verificationMode: verificationMode,
87 fallbackClientHandshake: fallbackClientHandshakeFunc,
88 getS2AStream: getS2AStream,
89 serverAuthorizationPolicy: serverAuthorizationPolicy,
90 }
91 if err != nil {
92 creds.tokenManager = nil
93 } else {
94 creds.tokenManager = &accessTokenManager
95 }
96 if grpclog.V(1) {
97 grpclog.Info("Created client S2Av2 transport credentials.")
98 }
99 return creds, nil
100 }
101
102 // NewServerCreds returns a server-side transport credentials object that uses
103 // the S2Av2 to establish a secure connection with a client.
104 func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream stream.GetS2AStream) (credentials.TransportCredentials, error) {
105 // Create an AccessTokenManager instance to use to authenticate to S2Av2.
106 accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
107 creds := &s2av2TransportCreds{
108 info: &credentials.ProtocolInfo{
109 SecurityProtocol: s2aSecurityProtocol,
110 },
111 isClient: false,
112 s2av2Address: s2av2Address,
113 transportCreds: transportCreds,
114 localIdentities: localIdentities,
115 verificationMode: verificationMode,
116 getS2AStream: getS2AStream,
117 }
118 if err != nil {
119 creds.tokenManager = nil
120 } else {
121 creds.tokenManager = &accessTokenManager
122 }
123 if grpclog.V(1) {
124 grpclog.Info("Created server S2Av2 transport credentials.")
125 }
126 return creds, nil
127 }
128
129 // ClientHandshake performs a client-side mTLS handshake using the S2Av2.
130 func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
131 if !c.isClient {
132 return nil, nil, errors.New("client handshake called using server transport credentials")
133 }
134 // Remove the port from serverAuthority.
135 serverName := removeServerNamePort(serverAuthority)
136 timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
137 defer cancel()
138 var s2AStream stream.S2AStream
139 var err error
140 retry.Run(timeoutCtx,
141 func() error {
142 s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
143 return err
144 })
145 if err != nil {
146 grpclog.Infof("Failed to connect to S2Av2: %v", err)
147 if c.fallbackClientHandshake != nil {
148 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
149 }
150 return nil, nil, err
151 }
152 defer s2AStream.CloseSend()
153 if grpclog.V(1) {
154 grpclog.Infof("Connected to S2Av2.")
155 }
156 var config *tls.Config
157
158 var tokenManager tokenmanager.AccessTokenManager
159 if c.tokenManager == nil {
160 tokenManager = nil
161 } else {
162 tokenManager = *c.tokenManager
163 }
164
165 sn := serverName
166 if c.serverName != "" {
167 sn = c.serverName
168 }
169 retry.Run(timeoutCtx,
170 func() error {
171 config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
172 return err
173 })
174 if err != nil {
175 grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
176 if c.fallbackClientHandshake != nil {
177 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
178 }
179 return nil, nil, err
180 }
181 if grpclog.V(1) {
182 grpclog.Infof("Got client TLS config from S2Av2.")
183 }
184
185 creds := credentials.NewTLS(config)
186 conn, authInfo, err := creds.ClientHandshake(timeoutCtx, serverName, rawConn)
187 if err != nil {
188 grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
189 if c.fallbackClientHandshake != nil {
190 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
191 }
192 return nil, nil, err
193 }
194 grpclog.Infof("client-side handshake is done using S2Av2 to: %s", serverName)
195
196 return conn, authInfo, err
197 }
198
199 // ServerHandshake performs a server-side mTLS handshake using the S2Av2.
200 func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
201 if c.isClient {
202 return nil, nil, errors.New("server handshake called using client transport credentials")
203 }
204 ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
205 defer cancel()
206 var s2AStream stream.S2AStream
207 var err error
208 retry.Run(ctx,
209 func() error {
210 s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
211 return err
212 })
213 if err != nil {
214 grpclog.Infof("Failed to connect to S2Av2: %v", err)
215 return nil, nil, err
216 }
217 defer s2AStream.CloseSend()
218 if grpclog.V(1) {
219 grpclog.Infof("Connected to S2Av2.")
220 }
221
222 var tokenManager tokenmanager.AccessTokenManager
223 if c.tokenManager == nil {
224 tokenManager = nil
225 } else {
226 tokenManager = *c.tokenManager
227 }
228
229 var config *tls.Config
230 retry.Run(ctx,
231 func() error {
232 config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
233 return err
234 })
235 if err != nil {
236 grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
237 return nil, nil, err
238 }
239 if grpclog.V(1) {
240 grpclog.Infof("Got server TLS config from S2Av2.")
241 }
242
243 creds := credentials.NewTLS(config)
244 conn, authInfo, err := creds.ServerHandshake(rawConn)
245 if err != nil {
246 grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
247 return nil, nil, err
248 }
249 return conn, authInfo, err
250 }
251
252 // Info returns protocol info of s2av2TransportCreds.
253 func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
254 return *c.info
255 }
256
257 // Clone makes a deep copy of s2av2TransportCreds.
258 func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
259 info := *c.info
260 serverName := c.serverName
261 fallbackClientHandshake := c.fallbackClientHandshake
262
263 s2av2Address := c.s2av2Address
264 var tokenManager tokenmanager.AccessTokenManager
265 if c.tokenManager == nil {
266 tokenManager = nil
267 } else {
268 tokenManager = *c.tokenManager
269 }
270 verificationMode := c.verificationMode
271 var localIdentity *commonpb.Identity
272 if c.localIdentity != nil {
273 localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
274 }
275 var localIdentities []*commonpb.Identity
276 if c.localIdentities != nil {
277 localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
278 for i, localIdentity := range c.localIdentities {
279 localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
280 }
281 }
282 creds := &s2av2TransportCreds{
283 info: &info,
284 isClient: c.isClient,
285 serverName: serverName,
286 fallbackClientHandshake: fallbackClientHandshake,
287 s2av2Address: s2av2Address,
288 localIdentity: localIdentity,
289 localIdentities: localIdentities,
290 verificationMode: verificationMode,
291 }
292 if c.tokenManager == nil {
293 creds.tokenManager = nil
294 } else {
295 creds.tokenManager = &tokenManager
296 }
297 return creds
298 }
299
300 // NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
301 // a client. The tls.Config MUST only be used to establish a single TLS connection.
302 func NewClientTLSConfig(
303 ctx context.Context,
304 s2av2Address string,
305 transportCreds credentials.TransportCredentials,
306 tokenManager tokenmanager.AccessTokenManager,
307 verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
308 serverName string,
309 serverAuthorizationPolicy []byte,
310 getStream stream.GetS2AStream) (*tls.Config, error) {
311 s2AStream, err := createStream(ctx, s2av2Address, transportCreds, getStream)
312 if err != nil {
313 grpclog.Infof("Failed to connect to S2Av2: %v", err)
314 return nil, err
315 }
316
317 return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
318 }
319
320 // OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
321 // info. The ServerName MUST be a hostname.
322 func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
323 serverName := removeServerNamePort(serverNameOverride)
324 c.info.ServerName = serverName
325 c.serverName = serverName
326 return nil
327 }
328
// removeServerNamePort strips the trailing port from serverName. When
// serverName cannot be split into host and port (for example because it has
// no port), it is returned unchanged.
func removeServerNamePort(serverName string) string {
	if host, _, err := net.SplitHostPort(serverName); err == nil {
		return host
	}
	return serverName
}
337
338 type s2AGrpcStream struct {
339 stream s2av2pb.S2AService_SetUpSessionClient
340 }
341
342 func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
343 return x.stream.Send(m)
344 }
345
346 func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
347 return x.stream.Recv()
348 }
349
350 func (x s2AGrpcStream) CloseSend() error {
351 return x.stream.CloseSend()
352 }
353
354 func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream stream.GetS2AStream) (stream.S2AStream, error) {
355 if getS2AStream != nil {
356 return getS2AStream(ctx, s2av2Address)
357 }
358 // TODO(rmehta19): Consider whether to close the connection to S2Av2.
359 conn, err := service.Dial(ctx, s2av2Address, transportCreds)
360 if err != nil {
361 return nil, err
362 }
363 client := s2av2pb.NewS2AServiceClient(conn)
364 gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
365 if err != nil {
366 return nil, err
367 }
368 return &s2AGrpcStream{
369 stream: gRPCStream,
370 }, nil
371 }
372
373 // GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
374 func GetS2ATimeout() time.Duration {
375 timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
376 if err != nil {
377 return defaultS2ATimeout
378 }
379 return timeout
380 }
381