user.go raw

   1  // Copyright 2021 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impersonate
   6  
   7  import (
   8  	"bytes"
   9  	"context"
  10  	"encoding/json"
  11  	"fmt"
  12  	"io"
  13  	"net/http"
  14  	"net/url"
  15  	"strings"
  16  	"time"
  17  
  18  	"golang.org/x/oauth2"
  19  )
  20  
  21  // user provides an auth flow for domain-wide delegation, setting
  22  // CredentialsConfig.Subject to be the impersonated user.
  23  func user(ctx context.Context, c CredentialsConfig, client *http.Client, lifetime time.Duration, isStaticToken bool) (oauth2.TokenSource, error) {
  24  	u := userTokenSource{
  25  		client:          client,
  26  		targetPrincipal: c.TargetPrincipal,
  27  		subject:         c.Subject,
  28  		lifetime:        lifetime,
  29  	}
  30  	u.delegates = make([]string, len(c.Delegates))
  31  	for i, v := range c.Delegates {
  32  		u.delegates[i] = formatIAMServiceAccountName(v)
  33  	}
  34  	u.scopes = make([]string, len(c.Scopes))
  35  	copy(u.scopes, c.Scopes)
  36  	if isStaticToken {
  37  		tok, err := u.Token()
  38  		if err != nil {
  39  			return nil, err
  40  		}
  41  		return oauth2.StaticTokenSource(tok), nil
  42  	}
  43  	return oauth2.ReuseTokenSource(nil, u), nil
  44  }
  45  
  46  type claimSet struct {
  47  	Iss   string `json:"iss"`
  48  	Scope string `json:"scope,omitempty"`
  49  	Sub   string `json:"sub,omitempty"`
  50  	Aud   string `json:"aud"`
  51  	Iat   int64  `json:"iat"`
  52  	Exp   int64  `json:"exp"`
  53  }
  54  
  55  type signJWTRequest struct {
  56  	Payload   string   `json:"payload"`
  57  	Delegates []string `json:"delegates,omitempty"`
  58  }
  59  
  60  type signJWTResponse struct {
  61  	// KeyID is the key used to sign the JWT.
  62  	KeyID string `json:"keyId"`
  63  	// SignedJwt contains the automatically generated header; the
  64  	// client-supplied payload; and the signature, which is generated using
  65  	// the key referenced by the `kid` field in the header.
  66  	SignedJWT string `json:"signedJwt"`
  67  }
  68  
  69  type exchangeTokenResponse struct {
  70  	AccessToken string `json:"access_token"`
  71  	TokenType   string `json:"token_type"`
  72  	ExpiresIn   int64  `json:"expires_in"`
  73  }
  74  
  75  type userTokenSource struct {
  76  	client *http.Client
  77  
  78  	targetPrincipal string
  79  	subject         string
  80  	scopes          []string
  81  	lifetime        time.Duration
  82  	delegates       []string
  83  }
  84  
  85  func (u userTokenSource) Token() (*oauth2.Token, error) {
  86  	signedJWT, err := u.signJWT()
  87  	if err != nil {
  88  		return nil, err
  89  	}
  90  	return u.exchangeToken(signedJWT)
  91  }
  92  
  93  func (u userTokenSource) signJWT() (string, error) {
  94  	now := time.Now()
  95  	exp := now.Add(u.lifetime)
  96  	claims := claimSet{
  97  		Iss:   u.targetPrincipal,
  98  		Scope: strings.Join(u.scopes, " "),
  99  		Sub:   u.subject,
 100  		Aud:   fmt.Sprintf("%s/token", oauth2Endpoint),
 101  		Iat:   now.Unix(),
 102  		Exp:   exp.Unix(),
 103  	}
 104  	payloadBytes, err := json.Marshal(claims)
 105  	if err != nil {
 106  		return "", fmt.Errorf("impersonate: unable to marshal claims: %v", err)
 107  	}
 108  	signJWTReq := signJWTRequest{
 109  		Payload:   string(payloadBytes),
 110  		Delegates: u.delegates,
 111  	}
 112  
 113  	bodyBytes, err := json.Marshal(signJWTReq)
 114  	if err != nil {
 115  		return "", fmt.Errorf("impersonate: unable to marshal request: %v", err)
 116  	}
 117  	reqURL := fmt.Sprintf("%s/v1/%s:signJwt", iamCredentailsEndpoint, formatIAMServiceAccountName(u.targetPrincipal))
 118  	req, err := http.NewRequest("POST", reqURL, bytes.NewReader(bodyBytes))
 119  	if err != nil {
 120  		return "", fmt.Errorf("impersonate: unable to create request: %v", err)
 121  	}
 122  	req.Header.Set("Content-Type", "application/json")
 123  	rawResp, err := u.client.Do(req)
 124  	if err != nil {
 125  		return "", fmt.Errorf("impersonate: unable to sign JWT: %v", err)
 126  	}
 127  	body, err := io.ReadAll(io.LimitReader(rawResp.Body, 1<<20))
 128  	if err != nil {
 129  		return "", fmt.Errorf("impersonate: unable to read body: %v", err)
 130  	}
 131  	if c := rawResp.StatusCode; c < 200 || c > 299 {
 132  		return "", fmt.Errorf("impersonate: status code %d: %s", c, body)
 133  	}
 134  
 135  	var signJWTResp signJWTResponse
 136  	if err := json.Unmarshal(body, &signJWTResp); err != nil {
 137  		return "", fmt.Errorf("impersonate: unable to parse response: %v", err)
 138  	}
 139  	return signJWTResp.SignedJWT, nil
 140  }
 141  
 142  func (u userTokenSource) exchangeToken(signedJWT string) (*oauth2.Token, error) {
 143  	now := time.Now()
 144  	v := url.Values{}
 145  	v.Set("grant_type", "assertion")
 146  	v.Set("assertion_type", "http://oauth.net/grant_type/jwt/1.0/bearer")
 147  	v.Set("assertion", signedJWT)
 148  	rawResp, err := u.client.PostForm(fmt.Sprintf("%s/token", oauth2Endpoint), v)
 149  	if err != nil {
 150  		return nil, fmt.Errorf("impersonate: unable to exchange token: %v", err)
 151  	}
 152  	body, err := io.ReadAll(io.LimitReader(rawResp.Body, 1<<20))
 153  	if err != nil {
 154  		return nil, fmt.Errorf("impersonate: unable to read body: %v", err)
 155  	}
 156  	if c := rawResp.StatusCode; c < 200 || c > 299 {
 157  		return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
 158  	}
 159  
 160  	var tokenResp exchangeTokenResponse
 161  	if err := json.Unmarshal(body, &tokenResp); err != nil {
 162  		return nil, fmt.Errorf("impersonate: unable to parse response: %v", err)
 163  	}
 164  
 165  	return &oauth2.Token{
 166  		AccessToken: tokenResp.AccessToken,
 167  		TokenType:   tokenResp.TokenType,
 168  		Expiry:      now.Add(time.Second * time.Duration(tokenResp.ExpiresIn)),
 169  	}, nil
 170  }
 171