sts_exchange.go raw

   1  // Copyright 2023 Google LLC
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package stsexchange
  16  
  17  import (
  18  	"context"
  19  	"encoding/base64"
  20  	"encoding/json"
  21  	"fmt"
  22  	"log/slog"
  23  	"net/http"
  24  	"net/url"
  25  	"strconv"
  26  	"strings"
  27  
  28  	"cloud.google.com/go/auth"
  29  	"cloud.google.com/go/auth/internal"
  30  	"github.com/googleapis/gax-go/v2/internallog"
  31  )
  32  
  33  const (
  34  	// GrantType for a sts exchange.
  35  	GrantType = "urn:ietf:params:oauth:grant-type:token-exchange"
  36  	// TokenType for a sts exchange.
  37  	TokenType = "urn:ietf:params:oauth:token-type:access_token"
  38  
  39  	jwtTokenType = "urn:ietf:params:oauth:token-type:jwt"
  40  )
  41  
  42  // Options stores the configuration for making an sts exchange request.
  43  type Options struct {
  44  	Client         *http.Client
  45  	Logger         *slog.Logger
  46  	Endpoint       string
  47  	Request        *TokenRequest
  48  	Authentication ClientAuthentication
  49  	Headers        http.Header
  50  	// ExtraOpts are optional fields marshalled into the `options` field of the
  51  	// request body.
  52  	ExtraOpts    map[string]interface{}
  53  	RefreshToken string
  54  }
  55  
  56  // RefreshAccessToken performs the token exchange using a refresh token flow.
  57  func RefreshAccessToken(ctx context.Context, opts *Options) (*TokenResponse, error) {
  58  	data := url.Values{}
  59  	data.Set("grant_type", "refresh_token")
  60  	data.Set("refresh_token", opts.RefreshToken)
  61  	return doRequest(ctx, opts, data)
  62  }
  63  
  64  // ExchangeToken performs an oauth2 token exchange with the provided endpoint.
  65  func ExchangeToken(ctx context.Context, opts *Options) (*TokenResponse, error) {
  66  	data := url.Values{}
  67  	data.Set("audience", opts.Request.Audience)
  68  	data.Set("grant_type", GrantType)
  69  	data.Set("requested_token_type", TokenType)
  70  	data.Set("subject_token_type", opts.Request.SubjectTokenType)
  71  	data.Set("subject_token", opts.Request.SubjectToken)
  72  	data.Set("scope", strings.Join(opts.Request.Scope, " "))
  73  	if opts.ExtraOpts != nil {
  74  		opts, err := json.Marshal(opts.ExtraOpts)
  75  		if err != nil {
  76  			return nil, fmt.Errorf("credentials: failed to marshal additional options: %w", err)
  77  		}
  78  		data.Set("options", string(opts))
  79  	}
  80  	return doRequest(ctx, opts, data)
  81  }
  82  
  83  func doRequest(ctx context.Context, opts *Options, data url.Values) (*TokenResponse, error) {
  84  	opts.Authentication.InjectAuthentication(data, opts.Headers)
  85  	encodedData := data.Encode()
  86  	logger := internallog.New(opts.Logger)
  87  
  88  	req, err := http.NewRequestWithContext(ctx, "POST", opts.Endpoint, strings.NewReader(encodedData))
  89  	if err != nil {
  90  		return nil, fmt.Errorf("credentials: failed to properly build http request: %w", err)
  91  
  92  	}
  93  	for key, list := range opts.Headers {
  94  		for _, val := range list {
  95  			req.Header.Add(key, val)
  96  		}
  97  	}
  98  	req.Header.Set("Content-Length", strconv.Itoa(len(encodedData)))
  99  
 100  	logger.DebugContext(ctx, "sts token request", "request", internallog.HTTPRequest(req, []byte(encodedData)))
 101  	resp, body, err := internal.DoRequest(opts.Client, req)
 102  	if err != nil {
 103  		return nil, fmt.Errorf("credentials: invalid response from Secure Token Server: %w", err)
 104  	}
 105  	logger.DebugContext(ctx, "sts token response", "response", internallog.HTTPResponse(resp, body))
 106  	if c := resp.StatusCode; c < http.StatusOK || c > http.StatusMultipleChoices {
 107  		return nil, fmt.Errorf("credentials: status code %d: %s", c, body)
 108  	}
 109  	var stsResp TokenResponse
 110  	if err := json.Unmarshal(body, &stsResp); err != nil {
 111  		return nil, fmt.Errorf("credentials: failed to unmarshal response body from Secure Token Server: %w", err)
 112  	}
 113  
 114  	return &stsResp, nil
 115  }
 116  
// TokenRequest contains fields necessary to make an oauth2 token
// exchange.
//
// ExchangeToken sends Audience, Scope (joined with single spaces),
// SubjectToken and SubjectTokenType. It takes the grant type and the requested
// token type from the package constants GrantType and TokenType rather than
// from the GrantType and RequestedTokenType fields, and it does not read
// ActingParty or Resource.
type TokenRequest struct {
	ActingParty struct {
		ActorToken     string
		ActorTokenType string
	}
	GrantType          string
	Resource           string
	Audience           string
	Scope              []string
	RequestedTokenType string
	SubjectToken       string
	SubjectTokenType   string
}
 132  
// TokenResponse is used to decode the remote server response during
// an oauth2 token exchange.
//
// Each field is filled from the JSON member named in its struct tag; a member
// that is absent from the response leaves the field at its zero value.
// NOTE(review): ExpiresIn is presumably the token lifetime in seconds, as
// RFC 6749 section 5.1 defines expires_in; confirm against the callers.
type TokenResponse struct {
	AccessToken     string `json:"access_token"`
	IssuedTokenType string `json:"issued_token_type"`
	TokenType       string `json:"token_type"`
	ExpiresIn       int    `json:"expires_in"`
	Scope           string `json:"scope"`
	RefreshToken    string `json:"refresh_token"`
}
 143  
// ClientAuthentication represents an OAuth client ID and secret and the
// mechanism for passing these credentials as stated in rfc6749#2.3.1.
//
// AuthStyle selects the mechanism used by InjectAuthentication:
// auth.StyleInHeader sends the credentials as an HTTP Basic Authorization
// header, and every other value sends them as the client_id and client_secret
// form parameters. InjectAuthentication does nothing while ClientID or
// ClientSecret is empty.
type ClientAuthentication struct {
	AuthStyle    auth.Style
	ClientID     string
	ClientSecret string
}
 151  
 152  // InjectAuthentication is used to add authentication to a Secure Token Service
 153  // exchange request.  It modifies either the passed url.Values or http.Header
 154  // depending on the desired authentication format.
 155  func (c *ClientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
 156  	if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
 157  		return
 158  	}
 159  	switch c.AuthStyle {
 160  	case auth.StyleInHeader:
 161  		plainHeader := c.ClientID + ":" + c.ClientSecret
 162  		headers.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(plainHeader)))
 163  	default:
 164  		values.Set("client_id", c.ClientID)
 165  		values.Set("client_secret", c.ClientSecret)
 166  	}
 167  }
 168