threelegged.go raw

   1  // Copyright 2023 Google LLC
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package auth
  16  
  17  import (
  18  	"bytes"
  19  	"context"
  20  	"encoding/json"
  21  	"errors"
  22  	"fmt"
  23  	"log/slog"
  24  	"mime"
  25  	"net/http"
  26  	"net/url"
  27  	"strconv"
  28  	"strings"
  29  	"time"
  30  
  31  	"cloud.google.com/go/auth/internal"
  32  	"github.com/googleapis/gax-go/v2/internallog"
  33  )
  34  
  35  // AuthorizationHandler is a 3-legged-OAuth helper that prompts the user for
  36  // OAuth consent at the specified auth code URL and returns an auth code and
  37  // state upon approval.
  38  type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
  39  
  40  // Options3LO are the options for doing a 3-legged OAuth2 flow.
  41  type Options3LO struct {
  42  	// ClientID is the application's ID.
  43  	ClientID string
  44  	// ClientSecret is the application's secret. Not required if AuthHandlerOpts
  45  	// is set.
  46  	ClientSecret string
  47  	// AuthURL is the URL for authenticating.
  48  	AuthURL string
  49  	// TokenURL is the URL for retrieving a token.
  50  	TokenURL string
  51  	// AuthStyle is used to describe how to client info in the token request.
  52  	AuthStyle Style
  53  	// RefreshToken is the token used to refresh the credential. Not required
  54  	// if AuthHandlerOpts is set.
  55  	RefreshToken string
  56  	// RedirectURL is the URL to redirect users to. Optional.
  57  	RedirectURL string
  58  	// Scopes specifies requested permissions for the Token. Optional.
  59  	Scopes []string
  60  
  61  	// URLParams are the set of values to apply to the token exchange. Optional.
  62  	URLParams url.Values
  63  	// Client is the client to be used to make the underlying token requests.
  64  	// Optional.
  65  	Client *http.Client
  66  	// EarlyTokenExpiry is the time before the token expires that it should be
  67  	// refreshed. If not set the default value is 3 minutes and 45 seconds.
  68  	// Optional.
  69  	EarlyTokenExpiry time.Duration
  70  
  71  	// AuthHandlerOpts provides a set of options for doing a
  72  	// 3-legged OAuth2 flow with a custom [AuthorizationHandler]. Optional.
  73  	AuthHandlerOpts *AuthorizationHandlerOptions
  74  	// Logger is used for debug logging. If provided, logging will be enabled
  75  	// at the loggers configured level. By default logging is disabled unless
  76  	// enabled by setting GOOGLE_SDK_GO_LOGGING_LEVEL in which case a default
  77  	// logger will be used. Optional.
  78  	Logger *slog.Logger
  79  }
  80  
  81  func (o *Options3LO) validate() error {
  82  	if o == nil {
  83  		return errors.New("auth: options must be provided")
  84  	}
  85  	if o.ClientID == "" {
  86  		return errors.New("auth: client ID must be provided")
  87  	}
  88  	if o.AuthHandlerOpts == nil && o.ClientSecret == "" {
  89  		return errors.New("auth: client secret must be provided")
  90  	}
  91  	if o.AuthURL == "" {
  92  		return errors.New("auth: auth URL must be provided")
  93  	}
  94  	if o.TokenURL == "" {
  95  		return errors.New("auth: token URL must be provided")
  96  	}
  97  	if o.AuthStyle == StyleUnknown {
  98  		return errors.New("auth: auth style must be provided")
  99  	}
 100  	if o.AuthHandlerOpts == nil && o.RefreshToken == "" {
 101  		return errors.New("auth: refresh token must be provided")
 102  	}
 103  	return nil
 104  }
 105  
 106  func (o *Options3LO) logger() *slog.Logger {
 107  	return internallog.New(o.Logger)
 108  }
 109  
 110  // PKCEOptions holds parameters to support PKCE.
 111  type PKCEOptions struct {
 112  	// Challenge is the un-padded, base64-url-encoded string of the encrypted code verifier.
 113  	Challenge string // The un-padded, base64-url-encoded string of the encrypted code verifier.
 114  	// ChallengeMethod is the encryption method (ex. S256).
 115  	ChallengeMethod string
 116  	// Verifier is the original, non-encrypted secret.
 117  	Verifier string // The original, non-encrypted secret.
 118  }
 119  
 120  type tokenJSON struct {
 121  	AccessToken  string `json:"access_token"`
 122  	TokenType    string `json:"token_type"`
 123  	RefreshToken string `json:"refresh_token"`
 124  	ExpiresIn    int    `json:"expires_in"`
 125  	// error fields
 126  	ErrorCode        string `json:"error"`
 127  	ErrorDescription string `json:"error_description"`
 128  	ErrorURI         string `json:"error_uri"`
 129  }
 130  
 131  func (e *tokenJSON) expiry() (t time.Time) {
 132  	if v := e.ExpiresIn; v != 0 {
 133  		return time.Now().Add(time.Duration(v) * time.Second)
 134  	}
 135  	return
 136  }
 137  
 138  func (o *Options3LO) client() *http.Client {
 139  	if o.Client != nil {
 140  		return o.Client
 141  	}
 142  	return internal.DefaultClient()
 143  }
 144  
 145  // authCodeURL returns a URL that points to a OAuth2 consent page.
 146  func (o *Options3LO) authCodeURL(state string, values url.Values) string {
 147  	var buf bytes.Buffer
 148  	buf.WriteString(o.AuthURL)
 149  	v := url.Values{
 150  		"response_type": {"code"},
 151  		"client_id":     {o.ClientID},
 152  	}
 153  	if o.RedirectURL != "" {
 154  		v.Set("redirect_uri", o.RedirectURL)
 155  	}
 156  	if len(o.Scopes) > 0 {
 157  		v.Set("scope", strings.Join(o.Scopes, " "))
 158  	}
 159  	if state != "" {
 160  		v.Set("state", state)
 161  	}
 162  	if o.AuthHandlerOpts != nil {
 163  		if o.AuthHandlerOpts.PKCEOpts != nil &&
 164  			o.AuthHandlerOpts.PKCEOpts.Challenge != "" {
 165  			v.Set(codeChallengeKey, o.AuthHandlerOpts.PKCEOpts.Challenge)
 166  		}
 167  		if o.AuthHandlerOpts.PKCEOpts != nil &&
 168  			o.AuthHandlerOpts.PKCEOpts.ChallengeMethod != "" {
 169  			v.Set(codeChallengeMethodKey, o.AuthHandlerOpts.PKCEOpts.ChallengeMethod)
 170  		}
 171  	}
 172  	for k := range values {
 173  		v.Set(k, v.Get(k))
 174  	}
 175  	if strings.Contains(o.AuthURL, "?") {
 176  		buf.WriteByte('&')
 177  	} else {
 178  		buf.WriteByte('?')
 179  	}
 180  	buf.WriteString(v.Encode())
 181  	return buf.String()
 182  }
 183  
 184  // New3LOTokenProvider returns a [TokenProvider] based on the 3-legged OAuth2
 185  // configuration. The TokenProvider is caches and auto-refreshes tokens by
 186  // default.
 187  func New3LOTokenProvider(opts *Options3LO) (TokenProvider, error) {
 188  	if err := opts.validate(); err != nil {
 189  		return nil, err
 190  	}
 191  	if opts.AuthHandlerOpts != nil {
 192  		return new3LOTokenProviderWithAuthHandler(opts), nil
 193  	}
 194  	return NewCachedTokenProvider(&tokenProvider3LO{opts: opts, refreshToken: opts.RefreshToken, client: opts.client()}, &CachedTokenProviderOptions{
 195  		ExpireEarly: opts.EarlyTokenExpiry,
 196  	}), nil
 197  }
 198  
 199  // AuthorizationHandlerOptions provides a set of options to specify for doing a
 200  // 3-legged OAuth2 flow with a custom [AuthorizationHandler].
 201  type AuthorizationHandlerOptions struct {
 202  	// AuthorizationHandler specifies the handler used to for the authorization
 203  	// part of the flow.
 204  	Handler AuthorizationHandler
 205  	// State is used verify that the "state" is identical in the request and
 206  	// response before exchanging the auth code for OAuth2 token.
 207  	State string
 208  	// PKCEOpts allows setting configurations for PKCE. Optional.
 209  	PKCEOpts *PKCEOptions
 210  }
 211  
 212  func new3LOTokenProviderWithAuthHandler(opts *Options3LO) TokenProvider {
 213  	return NewCachedTokenProvider(&tokenProviderWithHandler{opts: opts, state: opts.AuthHandlerOpts.State}, &CachedTokenProviderOptions{
 214  		ExpireEarly: opts.EarlyTokenExpiry,
 215  	})
 216  }
 217  
 218  // exchange handles the final exchange portion of the 3lo flow. Returns a Token,
 219  // refreshToken, and error.
 220  func (o *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) {
 221  	// Build request
 222  	v := url.Values{
 223  		"grant_type": {"authorization_code"},
 224  		"code":       {code},
 225  	}
 226  	if o.RedirectURL != "" {
 227  		v.Set("redirect_uri", o.RedirectURL)
 228  	}
 229  	if o.AuthHandlerOpts != nil &&
 230  		o.AuthHandlerOpts.PKCEOpts != nil &&
 231  		o.AuthHandlerOpts.PKCEOpts.Verifier != "" {
 232  		v.Set(codeVerifierKey, o.AuthHandlerOpts.PKCEOpts.Verifier)
 233  	}
 234  	for k := range o.URLParams {
 235  		v.Set(k, o.URLParams.Get(k))
 236  	}
 237  	return fetchToken(ctx, o, v)
 238  }
 239  
 240  // This struct is not safe for concurrent access alone, but the way it is used
 241  // in this package by wrapping it with a cachedTokenProvider makes it so.
 242  type tokenProvider3LO struct {
 243  	opts         *Options3LO
 244  	client       *http.Client
 245  	refreshToken string
 246  }
 247  
 248  func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) {
 249  	if tp.refreshToken == "" {
 250  		return nil, errors.New("auth: token expired and refresh token is not set")
 251  	}
 252  	v := url.Values{
 253  		"grant_type":    {"refresh_token"},
 254  		"refresh_token": {tp.refreshToken},
 255  	}
 256  	for k := range tp.opts.URLParams {
 257  		v.Set(k, tp.opts.URLParams.Get(k))
 258  	}
 259  
 260  	tk, rt, err := fetchToken(ctx, tp.opts, v)
 261  	if err != nil {
 262  		return nil, err
 263  	}
 264  	if tp.refreshToken != rt && rt != "" {
 265  		tp.refreshToken = rt
 266  	}
 267  	return tk, err
 268  }
 269  
 270  type tokenProviderWithHandler struct {
 271  	opts  *Options3LO
 272  	state string
 273  }
 274  
 275  func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) {
 276  	url := tp.opts.authCodeURL(tp.state, nil)
 277  	code, state, err := tp.opts.AuthHandlerOpts.Handler(url)
 278  	if err != nil {
 279  		return nil, err
 280  	}
 281  	if state != tp.state {
 282  		return nil, errors.New("auth: state mismatch in 3-legged-OAuth flow")
 283  	}
 284  	tok, _, err := tp.opts.exchange(ctx, code)
 285  	return tok, err
 286  }
 287  
 288  // fetchToken returns a Token, refresh token, and/or an error.
 289  func fetchToken(ctx context.Context, o *Options3LO, v url.Values) (*Token, string, error) {
 290  	var refreshToken string
 291  	if o.AuthStyle == StyleInParams {
 292  		if o.ClientID != "" {
 293  			v.Set("client_id", o.ClientID)
 294  		}
 295  		if o.ClientSecret != "" {
 296  			v.Set("client_secret", o.ClientSecret)
 297  		}
 298  	}
 299  	req, err := http.NewRequestWithContext(ctx, "POST", o.TokenURL, strings.NewReader(v.Encode()))
 300  	if err != nil {
 301  		return nil, refreshToken, err
 302  	}
 303  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 304  	if o.AuthStyle == StyleInHeader {
 305  		req.SetBasicAuth(url.QueryEscape(o.ClientID), url.QueryEscape(o.ClientSecret))
 306  	}
 307  	logger := o.logger()
 308  
 309  	logger.DebugContext(ctx, "3LO token request", "request", internallog.HTTPRequest(req, []byte(v.Encode())))
 310  	// Make request
 311  	resp, body, err := internal.DoRequest(o.client(), req)
 312  	if err != nil {
 313  		return nil, refreshToken, err
 314  	}
 315  	logger.DebugContext(ctx, "3LO token response", "response", internallog.HTTPResponse(resp, body))
 316  	failureStatus := resp.StatusCode < 200 || resp.StatusCode > 299
 317  	tokError := &Error{
 318  		Response: resp,
 319  		Body:     body,
 320  	}
 321  
 322  	var token *Token
 323  	// errors ignored because of default switch on content
 324  	content, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
 325  	switch content {
 326  	case "application/x-www-form-urlencoded", "text/plain":
 327  		// some endpoints return a query string
 328  		vals, err := url.ParseQuery(string(body))
 329  		if err != nil {
 330  			if failureStatus {
 331  				return nil, refreshToken, tokError
 332  			}
 333  			return nil, refreshToken, fmt.Errorf("auth: cannot parse response: %w", err)
 334  		}
 335  		tokError.code = vals.Get("error")
 336  		tokError.description = vals.Get("error_description")
 337  		tokError.uri = vals.Get("error_uri")
 338  		token = &Token{
 339  			Value:    vals.Get("access_token"),
 340  			Type:     vals.Get("token_type"),
 341  			Metadata: make(map[string]interface{}, len(vals)),
 342  		}
 343  		for k, v := range vals {
 344  			token.Metadata[k] = v
 345  		}
 346  		refreshToken = vals.Get("refresh_token")
 347  		e := vals.Get("expires_in")
 348  		expires, _ := strconv.Atoi(e)
 349  		if expires != 0 {
 350  			token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
 351  		}
 352  	default:
 353  		var tj tokenJSON
 354  		if err = json.Unmarshal(body, &tj); err != nil {
 355  			if failureStatus {
 356  				return nil, refreshToken, tokError
 357  			}
 358  			return nil, refreshToken, fmt.Errorf("auth: cannot parse json: %w", err)
 359  		}
 360  		tokError.code = tj.ErrorCode
 361  		tokError.description = tj.ErrorDescription
 362  		tokError.uri = tj.ErrorURI
 363  		token = &Token{
 364  			Value:    tj.AccessToken,
 365  			Type:     tj.TokenType,
 366  			Expiry:   tj.expiry(),
 367  			Metadata: make(map[string]interface{}),
 368  		}
 369  		json.Unmarshal(body, &token.Metadata) // optional field, skip err check
 370  		refreshToken = tj.RefreshToken
 371  	}
 372  	// according to spec, servers should respond status 400 in error case
 373  	// https://www.rfc-editor.org/rfc/rfc6749#section-5.2
 374  	// but some unorthodox servers respond 200 in error case
 375  	if failureStatus || tokError.code != "" {
 376  		return nil, refreshToken, tokError
 377  	}
 378  	if token.Value == "" {
 379  		return nil, refreshToken, errors.New("auth: server response missing access_token")
 380  	}
 381  	return token, refreshToken, nil
 382  }
 383