sts_exchange.go raw
1 // Copyright 2020 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package stsexchange
6
7 import (
8 "context"
9 "encoding/json"
10 "fmt"
11 "io"
12 "net/http"
13 "net/url"
14 "strconv"
15 "strings"
16
17 "golang.org/x/oauth2"
18 )
19
// defaultHeader builds a fresh header set whose single entry is
// "Content-Type: application/x-www-form-urlencoded", matching the
// form-encoded bodies this package POSTs. A new map is returned on every
// call, so callers may modify it freely.
func defaultHeader() http.Header {
	return http.Header{
		"Content-Type": []string{"application/x-www-form-urlencoded"},
	}
}
25
26 // ExchangeToken performs an oauth2 token exchange with the provided endpoint.
27 // The first 4 fields are all mandatory. headers can be used to pass additional
28 // headers beyond the bare minimum required by the token exchange. options can
29 // be used to pass additional JSON-structured options to the remote server.
30 func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]any) (*Response, error) {
31 data := url.Values{}
32 data.Set("audience", request.Audience)
33 data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
34 data.Set("requested_token_type", "urn:ietf:params:oauth:token-type:access_token")
35 data.Set("subject_token_type", request.SubjectTokenType)
36 data.Set("subject_token", request.SubjectToken)
37 data.Set("scope", strings.Join(request.Scope, " "))
38 if options != nil {
39 opts, err := json.Marshal(options)
40 if err != nil {
41 return nil, fmt.Errorf("oauth2/google: failed to marshal additional options: %v", err)
42 }
43 data.Set("options", string(opts))
44 }
45
46 return makeRequest(ctx, endpoint, data, authentication, headers)
47 }
48
49 func RefreshAccessToken(ctx context.Context, endpoint string, refreshToken string, authentication ClientAuthentication, headers http.Header) (*Response, error) {
50 data := url.Values{}
51 data.Set("grant_type", "refresh_token")
52 data.Set("refresh_token", refreshToken)
53
54 return makeRequest(ctx, endpoint, data, authentication, headers)
55 }
56
57 func makeRequest(ctx context.Context, endpoint string, data url.Values, authentication ClientAuthentication, headers http.Header) (*Response, error) {
58 if headers == nil {
59 headers = defaultHeader()
60 }
61 client := oauth2.NewClient(ctx, nil)
62 authentication.InjectAuthentication(data, headers)
63 encodedData := data.Encode()
64
65 req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
66 if err != nil {
67 return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
68 }
69 req = req.WithContext(ctx)
70 for key, list := range headers {
71 for _, val := range list {
72 req.Header.Add(key, val)
73 }
74 }
75 req.Header.Add("Content-Length", strconv.Itoa(len(encodedData)))
76
77 resp, err := client.Do(req)
78
79 if err != nil {
80 return nil, fmt.Errorf("oauth2/google: invalid response from Secure Token Server: %v", err)
81 }
82 defer resp.Body.Close()
83
84 body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
85 if err != nil {
86 return nil, err
87 }
88 if c := resp.StatusCode; c < 200 || c > 299 {
89 return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
90 }
91 var stsResp Response
92 err = json.Unmarshal(body, &stsResp)
93 if err != nil {
94 return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
95
96 }
97
98 return &stsResp, nil
99 }
100
// TokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
//
// Note: ExchangeToken in this package only transmits Audience, Scope,
// SubjectToken and SubjectTokenType. It hard-codes grant_type and
// requested_token_type and does not send ActingParty, GrantType, Resource or
// RequestedTokenType.
type TokenExchangeRequest struct {
	// ActingParty holds an actor token and its type (presumably the RFC 8693
	// actor_token / actor_token_type pair — confirm). Not sent by ExchangeToken.
	ActingParty struct {
		ActorToken     string
		ActorTokenType string
	}
	GrantType          string   // not sent; ExchangeToken always uses the token-exchange grant
	Resource           string   // not sent by ExchangeToken
	Audience           string   // sent as the "audience" form parameter
	Scope              []string // joined with single spaces and sent as "scope"
	RequestedTokenType string   // not sent; ExchangeToken always requests an access token
	SubjectToken       string   // sent as "subject_token"
	SubjectTokenType   string   // sent as "subject_token_type"
}
115
// Response is used to decode the remote server response during an oauth2 token exchange.
//
// Each field is filled from the JSON member named in its struct tag; members
// missing from the reply leave the field at its zero value.
type Response struct {
	AccessToken     string `json:"access_token"`
	IssuedTokenType string `json:"issued_token_type"`
	TokenType       string `json:"token_type"`
	ExpiresIn       int    `json:"expires_in"` // presumably lifetime in seconds (RFC 6749 §5.1) — confirm for this STS
	Scope           string `json:"scope"`
	RefreshToken    string `json:"refresh_token"`
}
125