sts_exchange.go raw

   1  // Copyright 2020 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package stsexchange
   6  
   7  import (
   8  	"context"
   9  	"encoding/json"
  10  	"fmt"
  11  	"io"
  12  	"net/http"
  13  	"net/url"
  14  	"strconv"
  15  	"strings"
  16  
  17  	"golang.org/x/oauth2"
  18  )
  19  
  20  func defaultHeader() http.Header {
  21  	header := make(http.Header)
  22  	header.Add("Content-Type", "application/x-www-form-urlencoded")
  23  	return header
  24  }
  25  
  26  // ExchangeToken performs an oauth2 token exchange with the provided endpoint.
  27  // The first 4 fields are all mandatory.  headers can be used to pass additional
  28  // headers beyond the bare minimum required by the token exchange.  options can
  29  // be used to pass additional JSON-structured options to the remote server.
  30  func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]any) (*Response, error) {
  31  	data := url.Values{}
  32  	data.Set("audience", request.Audience)
  33  	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
  34  	data.Set("requested_token_type", "urn:ietf:params:oauth:token-type:access_token")
  35  	data.Set("subject_token_type", request.SubjectTokenType)
  36  	data.Set("subject_token", request.SubjectToken)
  37  	data.Set("scope", strings.Join(request.Scope, " "))
  38  	if options != nil {
  39  		opts, err := json.Marshal(options)
  40  		if err != nil {
  41  			return nil, fmt.Errorf("oauth2/google: failed to marshal additional options: %v", err)
  42  		}
  43  		data.Set("options", string(opts))
  44  	}
  45  
  46  	return makeRequest(ctx, endpoint, data, authentication, headers)
  47  }
  48  
  49  func RefreshAccessToken(ctx context.Context, endpoint string, refreshToken string, authentication ClientAuthentication, headers http.Header) (*Response, error) {
  50  	data := url.Values{}
  51  	data.Set("grant_type", "refresh_token")
  52  	data.Set("refresh_token", refreshToken)
  53  
  54  	return makeRequest(ctx, endpoint, data, authentication, headers)
  55  }
  56  
  57  func makeRequest(ctx context.Context, endpoint string, data url.Values, authentication ClientAuthentication, headers http.Header) (*Response, error) {
  58  	if headers == nil {
  59  		headers = defaultHeader()
  60  	}
  61  	client := oauth2.NewClient(ctx, nil)
  62  	authentication.InjectAuthentication(data, headers)
  63  	encodedData := data.Encode()
  64  
  65  	req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
  66  	if err != nil {
  67  		return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
  68  	}
  69  	req = req.WithContext(ctx)
  70  	for key, list := range headers {
  71  		for _, val := range list {
  72  			req.Header.Add(key, val)
  73  		}
  74  	}
  75  	req.Header.Add("Content-Length", strconv.Itoa(len(encodedData)))
  76  
  77  	resp, err := client.Do(req)
  78  
  79  	if err != nil {
  80  		return nil, fmt.Errorf("oauth2/google: invalid response from Secure Token Server: %v", err)
  81  	}
  82  	defer resp.Body.Close()
  83  
  84  	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
  85  	if err != nil {
  86  		return nil, err
  87  	}
  88  	if c := resp.StatusCode; c < 200 || c > 299 {
  89  		return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
  90  	}
  91  	var stsResp Response
  92  	err = json.Unmarshal(body, &stsResp)
  93  	if err != nil {
  94  		return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
  95  
  96  	}
  97  
  98  	return &stsResp, nil
  99  }
 100  
 101  // TokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
 102  type TokenExchangeRequest struct {
 103  	ActingParty struct {
 104  		ActorToken     string
 105  		ActorTokenType string
 106  	}
 107  	GrantType          string
 108  	Resource           string
 109  	Audience           string
 110  	Scope              []string
 111  	RequestedTokenType string
 112  	SubjectToken       string
 113  	SubjectTokenType   string
 114  }
 115  
// Response is used to decode the remote server response during an oauth2 token exchange.
//
// It is filled by encoding/json from the body of a successful (2xx) reply.
// Fields missing from the JSON keep their zero values.
type Response struct {
	// AccessToken is the issued token, decoded from "access_token".
	AccessToken     string `json:"access_token"`
	// IssuedTokenType is decoded from "issued_token_type".
	IssuedTokenType string `json:"issued_token_type"`
	// TokenType is decoded from "token_type" (presumably e.g. "Bearer"; not checked here).
	TokenType       string `json:"token_type"`
	// ExpiresIn is decoded from "expires_in"; presumably a lifetime in
	// seconds per RFC 8693 — confirm against the server. Zero if absent.
	ExpiresIn       int    `json:"expires_in"`
	// Scope is decoded from "scope" as a single string (presumably
	// space-delimited; it is not split here).
	Scope           string `json:"scope"`
	// RefreshToken is decoded from "refresh_token"; empty if the server sent none.
	RefreshToken    string `json:"refresh_token"`
}
 125