oauth2adapt.go raw
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
// Package oauth2adapt helps convert types used in [cloud.google.com/go/auth]
// and [golang.org/x/oauth2].
17 package oauth2adapt
18
19 import (
20 "context"
21 "encoding/json"
22 "errors"
23
24 "cloud.google.com/go/auth"
25 "golang.org/x/oauth2"
26 "golang.org/x/oauth2/google"
27 )
28
// Keys under which compute token details are carried between the two token
// representations: the oauth2* keys are looked up in oauth2.Token extras and
// the auth* keys in auth.Token.Metadata.
const (
	authTokenSourceKey    = "auth.google.tokenSource"
	authServiceAccountKey = "auth.google.serviceAccount"

	oauth2TokenSourceKey    = "oauth2.google.tokenSource"
	oauth2ServiceAccountKey = "oauth2.google.serviceAccount"
)
35
36 // TokenProviderFromTokenSource converts any [golang.org/x/oauth2.TokenSource]
37 // into a [cloud.google.com/go/auth.TokenProvider].
38 func TokenProviderFromTokenSource(ts oauth2.TokenSource) auth.TokenProvider {
39 return &tokenProviderAdapter{ts: ts}
40 }
41
42 type tokenProviderAdapter struct {
43 ts oauth2.TokenSource
44 }
45
46 // Token fulfills the [cloud.google.com/go/auth.TokenProvider] interface. It
47 // is a light wrapper around the underlying TokenSource.
48 func (tp *tokenProviderAdapter) Token(context.Context) (*auth.Token, error) {
49 tok, err := tp.ts.Token()
50 if err != nil {
51 var err2 *oauth2.RetrieveError
52 if ok := errors.As(err, &err2); ok {
53 return nil, AuthErrorFromRetrieveError(err2)
54 }
55 return nil, err
56 }
57 // Preserve compute token metadata, for both types of tokens.
58 metadata := map[string]interface{}{}
59 if val, ok := tok.Extra(oauth2TokenSourceKey).(string); ok {
60 metadata[authTokenSourceKey] = val
61 metadata[oauth2TokenSourceKey] = val
62 }
63 if val, ok := tok.Extra(oauth2ServiceAccountKey).(string); ok {
64 metadata[authServiceAccountKey] = val
65 metadata[oauth2ServiceAccountKey] = val
66 }
67 return &auth.Token{
68 Value: tok.AccessToken,
69 Type: tok.Type(),
70 Expiry: tok.Expiry,
71 Metadata: metadata,
72 }, nil
73 }
74
75 // TokenSourceFromTokenProvider converts any
76 // [cloud.google.com/go/auth.TokenProvider] into a
77 // [golang.org/x/oauth2.TokenSource].
78 func TokenSourceFromTokenProvider(tp auth.TokenProvider) oauth2.TokenSource {
79 return &tokenSourceAdapter{tp: tp}
80 }
81
82 type tokenSourceAdapter struct {
83 tp auth.TokenProvider
84 }
85
86 // Token fulfills the [golang.org/x/oauth2.TokenSource] interface. It
87 // is a light wrapper around the underlying TokenProvider.
88 func (ts *tokenSourceAdapter) Token() (*oauth2.Token, error) {
89 tok, err := ts.tp.Token(context.Background())
90 if err != nil {
91 var err2 *auth.Error
92 if ok := errors.As(err, &err2); ok {
93 return nil, AddRetrieveErrorToAuthError(err2)
94 }
95 return nil, err
96 }
97 tok2 := &oauth2.Token{
98 AccessToken: tok.Value,
99 TokenType: tok.Type,
100 Expiry: tok.Expiry,
101 }
102 // Preserve token metadata.
103 m := tok.Metadata
104 if m != nil {
105 // Copy map to avoid concurrent map writes error (#11161).
106 metadata := make(map[string]interface{}, len(m)+2)
107 for k, v := range m {
108 metadata[k] = v
109 }
110 // Append compute token metadata in converted form.
111 if val, ok := metadata[authTokenSourceKey].(string); ok && val != "" {
112 metadata[oauth2TokenSourceKey] = val
113 }
114 if val, ok := metadata[authServiceAccountKey].(string); ok && val != "" {
115 metadata[oauth2ServiceAccountKey] = val
116 }
117 tok2 = tok2.WithExtra(metadata)
118 }
119 return tok2, nil
120 }
121
122 // AuthCredentialsFromOauth2Credentials converts a [golang.org/x/oauth2/google.Credentials]
123 // to a [cloud.google.com/go/auth.Credentials].
124 func AuthCredentialsFromOauth2Credentials(creds *google.Credentials) *auth.Credentials {
125 if creds == nil {
126 return nil
127 }
128 return auth.NewCredentials(&auth.CredentialsOptions{
129 TokenProvider: TokenProviderFromTokenSource(creds.TokenSource),
130 JSON: creds.JSON,
131 ProjectIDProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
132 return creds.ProjectID, nil
133 }),
134 UniverseDomainProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
135 return creds.GetUniverseDomain()
136 }),
137 })
138 }
139
140 // Oauth2CredentialsFromAuthCredentials converts a [cloud.google.com/go/auth.Credentials]
141 // to a [golang.org/x/oauth2/google.Credentials].
142 func Oauth2CredentialsFromAuthCredentials(creds *auth.Credentials) *google.Credentials {
143 if creds == nil {
144 return nil
145 }
146 // Throw away errors as old credentials are not request aware. Also, no
147 // network requests are currently happening for this use case.
148 projectID, _ := creds.ProjectID(context.Background())
149
150 return &google.Credentials{
151 TokenSource: TokenSourceFromTokenProvider(creds.TokenProvider),
152 ProjectID: projectID,
153 JSON: creds.JSON(),
154 UniverseDomainProvider: func() (string, error) {
155 return creds.UniverseDomain(context.Background())
156 },
157 }
158 }
159
// oauth2Error is the shape an error response body is decoded into when
// filling in the details of an oauth2.RetrieveError. Each field is mapped to
// the JSON member named in its tag.
type oauth2Error struct {
	// ErrorCode is decoded from the "error" member.
	ErrorCode string `json:"error"`
	// ErrorDescription is decoded from the "error_description" member.
	ErrorDescription string `json:"error_description"`
	// ErrorURI is decoded from the "error_uri" member.
	ErrorURI string `json:"error_uri"`
}
165
166 // AddRetrieveErrorToAuthError returns the same error provided and adds a
167 // [golang.org/x/oauth2.RetrieveError] to the error chain by setting the `Err` field on the
168 // [cloud.google.com/go/auth.Error].
169 func AddRetrieveErrorToAuthError(err *auth.Error) *auth.Error {
170 if err == nil {
171 return nil
172 }
173 e := &oauth2.RetrieveError{
174 Response: err.Response,
175 Body: err.Body,
176 }
177 err.Err = e
178 if len(err.Body) > 0 {
179 var oErr oauth2Error
180 // ignore the error as it only fills in extra details
181 json.Unmarshal(err.Body, &oErr)
182 e.ErrorCode = oErr.ErrorCode
183 e.ErrorDescription = oErr.ErrorDescription
184 e.ErrorURI = oErr.ErrorURI
185 }
186 return err
187 }
188
189 // AuthErrorFromRetrieveError returns an [cloud.google.com/go/auth.Error] that
190 // wraps the provided [golang.org/x/oauth2.RetrieveError].
191 func AuthErrorFromRetrieveError(err *oauth2.RetrieveError) *auth.Error {
192 if err == nil {
193 return nil
194 }
195 return &auth.Error{
196 Response: err.Response,
197 Body: err.Body,
198 Err: err,
199 }
200 }
201