oauth2adapt.go raw

   1  // Copyright 2023 Google LLC
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
// Package oauth2adapt helps convert types used in [cloud.google.com/go/auth]
// and [golang.org/x/oauth2].
  17  package oauth2adapt
  18  
  19  import (
  20  	"context"
  21  	"encoding/json"
  22  	"errors"
  23  
  24  	"cloud.google.com/go/auth"
  25  	"golang.org/x/oauth2"
  26  	"golang.org/x/oauth2/google"
  27  )
  28  
  29  const (
  30  	oauth2TokenSourceKey    = "oauth2.google.tokenSource"
  31  	oauth2ServiceAccountKey = "oauth2.google.serviceAccount"
  32  	authTokenSourceKey      = "auth.google.tokenSource"
  33  	authServiceAccountKey   = "auth.google.serviceAccount"
  34  )
  35  
  36  // TokenProviderFromTokenSource converts any [golang.org/x/oauth2.TokenSource]
  37  // into a [cloud.google.com/go/auth.TokenProvider].
  38  func TokenProviderFromTokenSource(ts oauth2.TokenSource) auth.TokenProvider {
  39  	return &tokenProviderAdapter{ts: ts}
  40  }
  41  
  42  type tokenProviderAdapter struct {
  43  	ts oauth2.TokenSource
  44  }
  45  
  46  // Token fulfills the [cloud.google.com/go/auth.TokenProvider] interface. It
  47  // is a light wrapper around the underlying TokenSource.
  48  func (tp *tokenProviderAdapter) Token(context.Context) (*auth.Token, error) {
  49  	tok, err := tp.ts.Token()
  50  	if err != nil {
  51  		var err2 *oauth2.RetrieveError
  52  		if ok := errors.As(err, &err2); ok {
  53  			return nil, AuthErrorFromRetrieveError(err2)
  54  		}
  55  		return nil, err
  56  	}
  57  	// Preserve compute token metadata, for both types of tokens.
  58  	metadata := map[string]interface{}{}
  59  	if val, ok := tok.Extra(oauth2TokenSourceKey).(string); ok {
  60  		metadata[authTokenSourceKey] = val
  61  		metadata[oauth2TokenSourceKey] = val
  62  	}
  63  	if val, ok := tok.Extra(oauth2ServiceAccountKey).(string); ok {
  64  		metadata[authServiceAccountKey] = val
  65  		metadata[oauth2ServiceAccountKey] = val
  66  	}
  67  	return &auth.Token{
  68  		Value:    tok.AccessToken,
  69  		Type:     tok.Type(),
  70  		Expiry:   tok.Expiry,
  71  		Metadata: metadata,
  72  	}, nil
  73  }
  74  
  75  // TokenSourceFromTokenProvider converts any
  76  // [cloud.google.com/go/auth.TokenProvider] into a
  77  // [golang.org/x/oauth2.TokenSource].
  78  func TokenSourceFromTokenProvider(tp auth.TokenProvider) oauth2.TokenSource {
  79  	return &tokenSourceAdapter{tp: tp}
  80  }
  81  
  82  type tokenSourceAdapter struct {
  83  	tp auth.TokenProvider
  84  }
  85  
  86  // Token fulfills the [golang.org/x/oauth2.TokenSource] interface. It
  87  // is a light wrapper around the underlying TokenProvider.
  88  func (ts *tokenSourceAdapter) Token() (*oauth2.Token, error) {
  89  	tok, err := ts.tp.Token(context.Background())
  90  	if err != nil {
  91  		var err2 *auth.Error
  92  		if ok := errors.As(err, &err2); ok {
  93  			return nil, AddRetrieveErrorToAuthError(err2)
  94  		}
  95  		return nil, err
  96  	}
  97  	tok2 := &oauth2.Token{
  98  		AccessToken: tok.Value,
  99  		TokenType:   tok.Type,
 100  		Expiry:      tok.Expiry,
 101  	}
 102  	// Preserve token metadata.
 103  	m := tok.Metadata
 104  	if m != nil {
 105  		// Copy map to avoid concurrent map writes error (#11161).
 106  		metadata := make(map[string]interface{}, len(m)+2)
 107  		for k, v := range m {
 108  			metadata[k] = v
 109  		}
 110  		// Append compute token metadata in converted form.
 111  		if val, ok := metadata[authTokenSourceKey].(string); ok && val != "" {
 112  			metadata[oauth2TokenSourceKey] = val
 113  		}
 114  		if val, ok := metadata[authServiceAccountKey].(string); ok && val != "" {
 115  			metadata[oauth2ServiceAccountKey] = val
 116  		}
 117  		tok2 = tok2.WithExtra(metadata)
 118  	}
 119  	return tok2, nil
 120  }
 121  
 122  // AuthCredentialsFromOauth2Credentials converts a [golang.org/x/oauth2/google.Credentials]
 123  // to a [cloud.google.com/go/auth.Credentials].
 124  func AuthCredentialsFromOauth2Credentials(creds *google.Credentials) *auth.Credentials {
 125  	if creds == nil {
 126  		return nil
 127  	}
 128  	return auth.NewCredentials(&auth.CredentialsOptions{
 129  		TokenProvider: TokenProviderFromTokenSource(creds.TokenSource),
 130  		JSON:          creds.JSON,
 131  		ProjectIDProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
 132  			return creds.ProjectID, nil
 133  		}),
 134  		UniverseDomainProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
 135  			return creds.GetUniverseDomain()
 136  		}),
 137  	})
 138  }
 139  
 140  // Oauth2CredentialsFromAuthCredentials converts a [cloud.google.com/go/auth.Credentials]
 141  // to a [golang.org/x/oauth2/google.Credentials].
 142  func Oauth2CredentialsFromAuthCredentials(creds *auth.Credentials) *google.Credentials {
 143  	if creds == nil {
 144  		return nil
 145  	}
 146  	// Throw away errors as old credentials are not request aware. Also, no
 147  	// network requests are currently happening for this use case.
 148  	projectID, _ := creds.ProjectID(context.Background())
 149  
 150  	return &google.Credentials{
 151  		TokenSource: TokenSourceFromTokenProvider(creds.TokenProvider),
 152  		ProjectID:   projectID,
 153  		JSON:        creds.JSON(),
 154  		UniverseDomainProvider: func() (string, error) {
 155  			return creds.UniverseDomain(context.Background())
 156  		},
 157  	}
 158  }
 159  
 160  type oauth2Error struct {
 161  	ErrorCode        string `json:"error"`
 162  	ErrorDescription string `json:"error_description"`
 163  	ErrorURI         string `json:"error_uri"`
 164  }
 165  
 166  // AddRetrieveErrorToAuthError returns the same error provided and adds a
 167  // [golang.org/x/oauth2.RetrieveError] to the error chain by setting the `Err` field on the
 168  // [cloud.google.com/go/auth.Error].
 169  func AddRetrieveErrorToAuthError(err *auth.Error) *auth.Error {
 170  	if err == nil {
 171  		return nil
 172  	}
 173  	e := &oauth2.RetrieveError{
 174  		Response: err.Response,
 175  		Body:     err.Body,
 176  	}
 177  	err.Err = e
 178  	if len(err.Body) > 0 {
 179  		var oErr oauth2Error
 180  		// ignore the error as it only fills in extra details
 181  		json.Unmarshal(err.Body, &oErr)
 182  		e.ErrorCode = oErr.ErrorCode
 183  		e.ErrorDescription = oErr.ErrorDescription
 184  		e.ErrorURI = oErr.ErrorURI
 185  	}
 186  	return err
 187  }
 188  
 189  // AuthErrorFromRetrieveError returns an [cloud.google.com/go/auth.Error] that
 190  // wraps the provided [golang.org/x/oauth2.RetrieveError].
 191  func AuthErrorFromRetrieveError(err *oauth2.RetrieveError) *auth.Error {
 192  	if err == nil {
 193  		return nil
 194  	}
 195  	return &auth.Error{
 196  		Response: err.Response,
 197  		Body:     err.Body,
 198  		Err:      err,
 199  	}
 200  }
 201