1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 15 // Package auth provides utilities for managing Google Cloud credentials,
16 // including functionality for creating, caching, and refreshing OAuth2 tokens.
17 // It offers customizable options for different OAuth2 flows, such as 2-legged
18 // (2LO) and 3-legged (3LO) OAuth, along with support for PKCE and automatic
19 // token management.
20 package auth
21 22 import (
23 "context"
24 "encoding/json"
25 "errors"
26 "fmt"
27 "log/slog"
28 "net/http"
29 "net/url"
30 "strings"
31 "sync"
32 "time"
33 34 "cloud.google.com/go/auth/internal"
35 "cloud.google.com/go/auth/internal/jwt"
36 "github.com/googleapis/gax-go/v2/internallog"
37 )
const (
	// Parameter keys for AuthCodeURL method to support PKCE.
	codeChallengeKey       = "code_challenge"
	codeChallengeMethodKey = "code_challenge_method"

	// Parameter key for Exchange method to support PKCE.
	codeVerifierKey = "code_verifier"

	// defaultExpiryDelta is how long before its expiry a token is treated as
	// expired by [Token.IsValid], and the default early-refresh window for
	// cached token providers: 3 minutes and 45 seconds. The shortest MDS cache
	// is 4 minutes, so this gives it 15 seconds to refresh its cache before we
	// attempt to refresh a token.
	defaultExpiryDelta = 225 * time.Second

	// universeDomainDefault is the service domain reported when no universe
	// domain is configured.
	universeDomainDefault = "googleapis.com"
)
// tokenState classifies a cached [Token] by how close it is to expiring.
type tokenState int

const (
	// fresh: the [Token] can be used as is. It is either far enough from its
	// expiry or has no expiry at all.
	fresh tokenState = iota
	// stale: the [Token] still works but is inside the early-expiry window,
	// so a refresh should be started.
	stale
	// invalid: the [Token] is missing, empty, or already expired and must not
	// be used for a normal operation.
	invalid
)
68 69 var (
70 defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
71 defaultHeader = &jwt.Header{Algorithm: jwt.HeaderAlgRSA256, Type: jwt.HeaderType}
72 73 // for testing
74 timeNow = time.Now
75 )
76 77 // TokenProvider specifies an interface for anything that can return a token.
78 type TokenProvider interface {
79 // Token returns a Token or an error.
80 // The Token returned must be safe to use
81 // concurrently.
82 // The returned Token must not be modified.
83 // The context provided must be sent along to any requests that are made in
84 // the implementing code.
85 Token(context.Context) (*Token, error)
86 }
// Token holds the credential token used to authorize requests. Every field
// should be treated as read-only.
type Token struct {
	// Value is the token that authorizes requests. It is usually an access
	// token, though some flows place other kinds of tokens here, such as ID
	// tokens.
	Value string
	// Type names the kind of token held in Value. An empty Type should be
	// treated as a "Bearer" token.
	Type string
	// Expiry is the moment the token expires. A zero Expiry means the token
	// never expires.
	Expiry time.Time
	// Metadata may include, but is not limited to, the body of the token
	// response returned by the server.
	Metadata map[string]any // TODO(codyoss): maybe make a method to flatten metadata to avoid []string for url.Values
}
103 104 // IsValid reports that a [Token] is non-nil, has a [Token.Value], and has not
105 // expired. A token is considered expired if [Token.Expiry] has passed or will
106 // pass in the next 225 seconds.
107 func (t *Token) IsValid() bool {
108 return t.isValidWithEarlyExpiry(defaultExpiryDelta)
109 }
110 111 // MetadataString is a convenience method for accessing string values in the
112 // token's metadata. Returns an empty string if the metadata is nil or the value
113 // for the given key cannot be cast to a string.
114 func (t *Token) MetadataString(k string) string {
115 if t.Metadata == nil {
116 return ""
117 }
118 s, ok := t.Metadata[k].(string)
119 if !ok {
120 return ""
121 }
122 return s
123 }
124 125 func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool {
126 if t.isEmpty() {
127 return false
128 }
129 if t.Expiry.IsZero() {
130 return true
131 }
132 return !t.Expiry.Round(0).Add(-earlyExpiry).Before(timeNow())
133 }
134 135 func (t *Token) isEmpty() bool {
136 return t == nil || t.Value == ""
137 }
138 139 // Credentials holds Google credentials, including
140 // [Application Default Credentials].
141 //
142 // [Application Default Credentials]: https://developers.google.com/accounts/docs/application-default-credentials
143 type Credentials struct {
144 json []byte
145 projectID CredentialsPropertyProvider
146 quotaProjectID CredentialsPropertyProvider
147 // universeDomain is the default service domain for a given Cloud universe.
148 universeDomain CredentialsPropertyProvider
149 150 TokenProvider
151 }
152 153 // JSON returns the bytes associated with the the file used to source
154 // credentials if one was used.
155 func (c *Credentials) JSON() []byte {
156 return c.json
157 }
158 159 // ProjectID returns the associated project ID from the underlying file or
160 // environment.
161 func (c *Credentials) ProjectID(ctx context.Context) (string, error) {
162 if c.projectID == nil {
163 return internal.GetProjectID(c.json, ""), nil
164 }
165 v, err := c.projectID.GetProperty(ctx)
166 if err != nil {
167 return "", err
168 }
169 return internal.GetProjectID(c.json, v), nil
170 }
171 172 // QuotaProjectID returns the associated quota project ID from the underlying
173 // file or environment.
174 func (c *Credentials) QuotaProjectID(ctx context.Context) (string, error) {
175 if c.quotaProjectID == nil {
176 return internal.GetQuotaProject(c.json, ""), nil
177 }
178 v, err := c.quotaProjectID.GetProperty(ctx)
179 if err != nil {
180 return "", err
181 }
182 return internal.GetQuotaProject(c.json, v), nil
183 }
184 185 // UniverseDomain returns the default service domain for a given Cloud universe.
186 // The default value is "googleapis.com".
187 func (c *Credentials) UniverseDomain(ctx context.Context) (string, error) {
188 if c.universeDomain == nil {
189 return universeDomainDefault, nil
190 }
191 v, err := c.universeDomain.GetProperty(ctx)
192 if err != nil {
193 return "", err
194 }
195 if v == "" {
196 return universeDomainDefault, nil
197 }
198 return v, err
199 }
// CredentialsPropertyProvider provides an implementation to fetch a property
// value for [Credentials].
type CredentialsPropertyProvider interface {
	GetProperty(context.Context) (string, error)
}

// CredentialsPropertyFunc is a type adapter to allow the use of ordinary
// functions as a [CredentialsPropertyProvider].
type CredentialsPropertyFunc func(context.Context) (string, error)

// Compile-time check that the adapter satisfies the interface.
var _ CredentialsPropertyProvider = CredentialsPropertyFunc(nil)

// GetProperty loads the property value for the given context by calling p.
func (p CredentialsPropertyFunc) GetProperty(ctx context.Context) (string, error) {
	return p(ctx)
}
215 216 // CredentialsOptions are used to configure [Credentials].
217 type CredentialsOptions struct {
218 // TokenProvider is a means of sourcing a token for the credentials. Required.
219 TokenProvider TokenProvider
220 // JSON is the raw contents of the credentials file if sourced from a file.
221 JSON []byte
222 // ProjectIDProvider resolves the project ID associated with the
223 // credentials.
224 ProjectIDProvider CredentialsPropertyProvider
225 // QuotaProjectIDProvider resolves the quota project ID associated with the
226 // credentials.
227 QuotaProjectIDProvider CredentialsPropertyProvider
228 // UniverseDomainProvider resolves the universe domain with the credentials.
229 UniverseDomainProvider CredentialsPropertyProvider
230 }
231 232 // NewCredentials returns new [Credentials] from the provided options.
233 func NewCredentials(opts *CredentialsOptions) *Credentials {
234 creds := &Credentials{
235 TokenProvider: opts.TokenProvider,
236 json: opts.JSON,
237 projectID: opts.ProjectIDProvider,
238 quotaProjectID: opts.QuotaProjectIDProvider,
239 universeDomain: opts.UniverseDomainProvider,
240 }
241 242 return creds
243 }
244 245 // CachedTokenProviderOptions provides options for configuring a cached
246 // [TokenProvider].
247 type CachedTokenProviderOptions struct {
248 // DisableAutoRefresh makes the TokenProvider always return the same token,
249 // even if it is expired. The default is false. Optional.
250 DisableAutoRefresh bool
251 // ExpireEarly configures the amount of time before a token expires, that it
252 // should be refreshed. If unset, the default value is 3 minutes and 45
253 // seconds. Optional.
254 ExpireEarly time.Duration
255 // DisableAsyncRefresh configures a synchronous workflow that refreshes
256 // tokens in a blocking manner. The default is false. Optional.
257 DisableAsyncRefresh bool
258 }
259 260 func (ctpo *CachedTokenProviderOptions) autoRefresh() bool {
261 if ctpo == nil {
262 return true
263 }
264 return !ctpo.DisableAutoRefresh
265 }
266 267 func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration {
268 if ctpo == nil || ctpo.ExpireEarly == 0 {
269 return defaultExpiryDelta
270 }
271 return ctpo.ExpireEarly
272 }
273 274 func (ctpo *CachedTokenProviderOptions) blockingRefresh() bool {
275 if ctpo == nil {
276 return false
277 }
278 return ctpo.DisableAsyncRefresh
279 }
280 281 // NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned
282 // by the underlying provider. By default it will refresh tokens asynchronously
283 // a few minutes before they expire.
284 func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider {
285 if ctp, ok := tp.(*cachedTokenProvider); ok {
286 return ctp
287 }
288 return &cachedTokenProvider{
289 tp: tp,
290 autoRefresh: opts.autoRefresh(),
291 expireEarly: opts.expireEarly(),
292 blockingRefresh: opts.blockingRefresh(),
293 }
294 }
295 296 type cachedTokenProvider struct {
297 tp TokenProvider
298 autoRefresh bool
299 expireEarly time.Duration
300 blockingRefresh bool
301 302 mu sync.Mutex
303 cachedToken *Token
304 // isRefreshRunning ensures that the non-blocking refresh will only be
305 // attempted once, even if multiple callers enter the Token method.
306 isRefreshRunning bool
307 // isRefreshErr ensures that the non-blocking refresh will only be attempted
308 // once per refresh window if an error is encountered.
309 isRefreshErr bool
310 }
311 312 func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) {
313 if c.blockingRefresh {
314 return c.tokenBlocking(ctx)
315 }
316 return c.tokenNonBlocking(ctx)
317 }
318 319 func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, error) {
320 switch c.tokenState() {
321 case fresh:
322 c.mu.Lock()
323 defer c.mu.Unlock()
324 return c.cachedToken, nil
325 case stale:
326 // Call tokenAsync with a new Context because the user-provided context
327 // may have a short timeout incompatible with async token refresh.
328 c.tokenAsync(context.Background())
329 // Return the stale token immediately to not block customer requests to Cloud services.
330 c.mu.Lock()
331 defer c.mu.Unlock()
332 return c.cachedToken, nil
333 default: // invalid
334 return c.tokenBlocking(ctx)
335 }
336 }
337 338 // tokenState reports the token's validity.
339 func (c *cachedTokenProvider) tokenState() tokenState {
340 c.mu.Lock()
341 defer c.mu.Unlock()
342 t := c.cachedToken
343 now := timeNow()
344 if t == nil || t.Value == "" {
345 return invalid
346 } else if t.Expiry.IsZero() {
347 return fresh
348 } else if now.After(t.Expiry.Round(0)) {
349 return invalid
350 } else if now.After(t.Expiry.Round(0).Add(-c.expireEarly)) {
351 return stale
352 }
353 return fresh
354 }
355 356 // tokenAsync uses a bool to ensure that only one non-blocking token refresh
357 // happens at a time, even if multiple callers have entered this function
358 // concurrently. This avoids creating an arbitrary number of concurrent
359 // goroutines. Retries should be attempted and managed within the Token method.
360 // If the refresh attempt fails, no further attempts are made until the refresh
361 // window expires and the token enters the invalid state, at which point the
362 // blocking call to Token should likely return the same error on the main goroutine.
363 func (c *cachedTokenProvider) tokenAsync(ctx context.Context) {
364 fn := func() {
365 t, err := c.tp.Token(ctx)
366 c.mu.Lock()
367 defer c.mu.Unlock()
368 c.isRefreshRunning = false
369 if err != nil {
370 // Discard errors from the non-blocking refresh, but prevent further
371 // attempts.
372 c.isRefreshErr = true
373 return
374 }
375 c.cachedToken = t
376 }
377 c.mu.Lock()
378 defer c.mu.Unlock()
379 if !c.isRefreshRunning && !c.isRefreshErr {
380 c.isRefreshRunning = true
381 go fn()
382 }
383 }
384 385 func (c *cachedTokenProvider) tokenBlocking(ctx context.Context) (*Token, error) {
386 c.mu.Lock()
387 defer c.mu.Unlock()
388 c.isRefreshErr = false
389 if c.cachedToken.IsValid() || (!c.autoRefresh && !c.cachedToken.isEmpty()) {
390 return c.cachedToken, nil
391 }
392 t, err := c.tp.Token(ctx)
393 if err != nil {
394 return nil, err
395 }
396 c.cachedToken = t
397 return t, nil
398 }
// Error is an error associated with retrieving a [Token]. It can hold useful
// additional details for debugging.
type Error struct {
	// Response is the HTTP response associated with error. The body will always
	// be already closed and consumed.
	Response *http.Response
	// Body is the HTTP response body.
	Body []byte
	// Err is the underlying wrapped error.
	Err error

	// code returned in the token response
	code string
	// description returned in the token response
	description string
	// uri returned in the token response
	uri string
}

// Error formats the error. If the token response carried an error code, the
// code and any description and URI are reported. If not, the HTTP status and
// body are reported. When there is no HTTP response (for example, a
// wrapped transport error), the wrapped error and any body are reported
// instead of dereferencing a nil Response.
func (e *Error) Error() string {
	if e.code != "" {
		s := fmt.Sprintf("auth: %q", e.code)
		if e.description != "" {
			s += fmt.Sprintf(" %q", e.description)
		}
		if e.uri != "" {
			s += fmt.Sprintf(" %q", e.uri)
		}
		return s
	}
	if e.Response == nil {
		s := "auth: cannot fetch token"
		if e.Err != nil {
			s += ": " + e.Err.Error()
		}
		if len(e.Body) > 0 {
			s += fmt.Sprintf("\nResponse: %s", e.Body)
		}
		return s
	}
	return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", e.Response.StatusCode, e.Body)
}

// Temporary returns true if the error is considered temporary and may be able
// to be retried: the HTTP status was 500, 503, 408 or 429. Errors without an
// HTTP response are never temporary.
func (e *Error) Temporary() bool {
	if e.Response == nil {
		return false
	}
	switch e.Response.StatusCode {
	case http.StatusInternalServerError, http.StatusServiceUnavailable,
		http.StatusRequestTimeout, http.StatusTooManyRequests:
		return true
	}
	return false
}

// Unwrap returns the underlying error so that errors.Is and errors.As can
// inspect it.
func (e *Error) Unwrap() error {
	return e.Err
}
// Style describes how the token endpoint expects to receive the ClientID and
// ClientSecret.
type Style int

const (
	// StyleUnknown is the zero value: no style has been chosen. Sending it in
	// a request will cause the token exchange to fail.
	StyleUnknown Style = iota
	// StyleInParams puts the client info in the body of a POST request.
	StyleInParams
	// StyleInHeader sends the client info in a Basic Authorization header.
	StyleInHeader
)
// Options2LO is the configuration settings for doing a 2-legged JWT OAuth2 flow.
type Options2LO struct {
	// Email is the OAuth2 client ID. This value is set as the "iss" in the
	// JWT. Required.
	Email string
	// PrivateKey contains the contents of an RSA private key or the
	// contents of a PEM file that contains a private key. It is used to sign
	// the JWT created. Required.
	PrivateKey []byte
	// TokenURL is the URL the JWT is sent to. It is also the default "aud" in
	// the JWT. Required.
	TokenURL string
	// PrivateKeyID is the ID of the key used to sign the JWT. It is used as the
	// "kid" in the JWT header. Optional.
	PrivateKeyID string
	// Subject is the user to impersonate, if any. It is used as the "sub" in
	// the JWT. Optional.
	Subject string
	// Scopes specifies requested permissions for the token. They are joined
	// with spaces into the JWT "scope" claim. Optional.
	Scopes []string
	// Expires, when positive, sets the JWT "exp" claim to the current time
	// plus this duration. Optional.
	Expires time.Duration
	// Audience specifies the "aud" in the JWT, overriding TokenURL when set.
	// Optional.
	Audience string
	// PrivateClaims allows specifying any custom claims for the JWT. Optional.
	PrivateClaims map[string]interface{}
	// UniverseDomain is the default service domain for a given Cloud universe.
	UniverseDomain string

	// Client is the client to be used to make the underlying token requests.
	// If nil, a default client is used. Optional.
	Client *http.Client
	// UseIDToken requests that the token returned be an ID token. If the
	// server response does not include one, fetching a token fails.
	// Optional.
	UseIDToken bool
	// Logger is used for debug logging. If provided, logging will be enabled
	// at the logger's configured level. By default logging is disabled unless
	// enabled by setting GOOGLE_SDK_GO_LOGGING_LEVEL in which case a default
	// logger will be used. Optional.
	Logger *slog.Logger
}
501 502 func (o *Options2LO) client() *http.Client {
503 if o.Client != nil {
504 return o.Client
505 }
506 return internal.DefaultClient()
507 }
508 509 func (o *Options2LO) validate() error {
510 if o == nil {
511 return errors.New("auth: options must be provided")
512 }
513 if o.Email == "" {
514 return errors.New("auth: email must be provided")
515 }
516 if len(o.PrivateKey) == 0 {
517 return errors.New("auth: private key must be provided")
518 }
519 if o.TokenURL == "" {
520 return errors.New("auth: token URL must be provided")
521 }
522 return nil
523 }
524 525 // New2LOTokenProvider returns a [TokenProvider] from the provided options.
526 func New2LOTokenProvider(opts *Options2LO) (TokenProvider, error) {
527 if err := opts.validate(); err != nil {
528 return nil, err
529 }
530 return tokenProvider2LO{opts: opts, Client: opts.client(), logger: internallog.New(opts.Logger)}, nil
531 }
532 533 type tokenProvider2LO struct {
534 opts *Options2LO
535 Client *http.Client
536 logger *slog.Logger
537 }
538 539 func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) {
540 pk, err := internal.ParseKey(tp.opts.PrivateKey)
541 if err != nil {
542 return nil, err
543 }
544 claimSet := &jwt.Claims{
545 Iss: tp.opts.Email,
546 Scope: strings.Join(tp.opts.Scopes, " "),
547 Aud: tp.opts.TokenURL,
548 AdditionalClaims: tp.opts.PrivateClaims,
549 Sub: tp.opts.Subject,
550 }
551 if t := tp.opts.Expires; t > 0 {
552 claimSet.Exp = time.Now().Add(t).Unix()
553 }
554 if aud := tp.opts.Audience; aud != "" {
555 claimSet.Aud = aud
556 }
557 h := *defaultHeader
558 h.KeyID = tp.opts.PrivateKeyID
559 payload, err := jwt.EncodeJWS(&h, claimSet, pk)
560 if err != nil {
561 return nil, err
562 }
563 v := url.Values{}
564 v.Set("grant_type", defaultGrantType)
565 v.Set("assertion", payload)
566 req, err := http.NewRequestWithContext(ctx, "POST", tp.opts.TokenURL, strings.NewReader(v.Encode()))
567 if err != nil {
568 return nil, err
569 }
570 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
571 tp.logger.DebugContext(ctx, "2LO token request", "request", internallog.HTTPRequest(req, []byte(v.Encode())))
572 resp, body, err := internal.DoRequest(tp.Client, req)
573 if err != nil {
574 return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
575 }
576 tp.logger.DebugContext(ctx, "2LO token response", "response", internallog.HTTPResponse(resp, body))
577 if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices {
578 return nil, &Error{
579 Response: resp,
580 Body: body,
581 }
582 }
583 // tokenRes is the JSON response body.
584 var tokenRes struct {
585 AccessToken string `json:"access_token"`
586 TokenType string `json:"token_type"`
587 IDToken string `json:"id_token"`
588 ExpiresIn int64 `json:"expires_in"`
589 }
590 if err := json.Unmarshal(body, &tokenRes); err != nil {
591 return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
592 }
593 token := &Token{
594 Value: tokenRes.AccessToken,
595 Type: tokenRes.TokenType,
596 }
597 token.Metadata = make(map[string]interface{})
598 json.Unmarshal(body, &token.Metadata) // no error checks for optional fields
599 600 if secs := tokenRes.ExpiresIn; secs > 0 {
601 token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
602 }
603 if v := tokenRes.IDToken; v != "" {
604 // decode returned id token to get expiry
605 claimSet, err := jwt.DecodeJWS(v)
606 if err != nil {
607 return nil, fmt.Errorf("auth: error decoding JWT token: %w", err)
608 }
609 token.Expiry = time.Unix(claimSet.Exp, 0)
610 }
611 if tp.opts.UseIDToken {
612 if tokenRes.IDToken == "" {
613 return nil, fmt.Errorf("auth: response doesn't have JWT token")
614 }
615 token.Value = tokenRes.IDToken
616 }
617 return token, nil
618 }
619