gdch.go raw

   1  // Copyright 2023 Google LLC
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package gdch
  16  
  17  import (
  18  	"context"
  19  	"crypto"
  20  	"crypto/tls"
  21  	"crypto/x509"
  22  	"encoding/json"
  23  	"errors"
  24  	"fmt"
  25  	"log/slog"
  26  	"net/http"
  27  	"net/url"
  28  	"os"
  29  	"strings"
  30  	"time"
  31  
  32  	"cloud.google.com/go/auth"
  33  	"cloud.google.com/go/auth/internal"
  34  	"cloud.google.com/go/auth/internal/credsfile"
  35  	"cloud.google.com/go/auth/internal/jwt"
  36  	"github.com/googleapis/gax-go/v2/internallog"
  37  )
  38  
  39  const (
  40  	// GrantType is the grant type for the token request.
  41  	GrantType        = "urn:ietf:params:oauth:token-type:token-exchange"
  42  	requestTokenType = "urn:ietf:params:oauth:token-type:access_token"
  43  	subjectTokenType = "urn:k8s:params:oauth:token-type:serviceaccount"
  44  )
  45  
  46  var (
  47  	gdchSupportFormatVersions map[string]bool = map[string]bool{
  48  		"1": true,
  49  	}
  50  )
  51  
  52  // Options for [NewTokenProvider].
  53  type Options struct {
  54  	STSAudience string
  55  	Client      *http.Client
  56  	Logger      *slog.Logger
  57  }
  58  
  59  // NewTokenProvider returns a [cloud.google.com/go/auth.TokenProvider] from a
  60  // GDCH cred file.
  61  func NewTokenProvider(f *credsfile.GDCHServiceAccountFile, o *Options) (auth.TokenProvider, error) {
  62  	if !gdchSupportFormatVersions[f.FormatVersion] {
  63  		return nil, fmt.Errorf("credentials: unsupported gdch_service_account format %q", f.FormatVersion)
  64  	}
  65  	if o.STSAudience == "" {
  66  		return nil, errors.New("credentials: STSAudience must be set for the GDCH auth flows")
  67  	}
  68  	signer, err := internal.ParseKey([]byte(f.PrivateKey))
  69  	if err != nil {
  70  		return nil, err
  71  	}
  72  	certPool, err := loadCertPool(f.CertPath)
  73  	if err != nil {
  74  		return nil, err
  75  	}
  76  
  77  	tp := gdchProvider{
  78  		serviceIdentity: fmt.Sprintf("system:serviceaccount:%s:%s", f.Project, f.Name),
  79  		tokenURL:        f.TokenURL,
  80  		aud:             o.STSAudience,
  81  		signer:          signer,
  82  		pkID:            f.PrivateKeyID,
  83  		certPool:        certPool,
  84  		client:          o.Client,
  85  		logger:          internallog.New(o.Logger),
  86  	}
  87  	return tp, nil
  88  }
  89  
  90  func loadCertPool(path string) (*x509.CertPool, error) {
  91  	pool := x509.NewCertPool()
  92  	pem, err := os.ReadFile(path)
  93  	if err != nil {
  94  		return nil, fmt.Errorf("credentials: failed to read certificate: %w", err)
  95  	}
  96  	pool.AppendCertsFromPEM(pem)
  97  	return pool, nil
  98  }
  99  
 100  type gdchProvider struct {
 101  	serviceIdentity string
 102  	tokenURL        string
 103  	aud             string
 104  	signer          crypto.Signer
 105  	pkID            string
 106  	certPool        *x509.CertPool
 107  
 108  	client *http.Client
 109  	logger *slog.Logger
 110  }
 111  
 112  func (g gdchProvider) Token(ctx context.Context) (*auth.Token, error) {
 113  	addCertToTransport(g.client, g.certPool)
 114  	iat := time.Now()
 115  	exp := iat.Add(time.Hour)
 116  	claims := jwt.Claims{
 117  		Iss: g.serviceIdentity,
 118  		Sub: g.serviceIdentity,
 119  		Aud: g.tokenURL,
 120  		Iat: iat.Unix(),
 121  		Exp: exp.Unix(),
 122  	}
 123  	h := jwt.Header{
 124  		Algorithm: jwt.HeaderAlgRSA256,
 125  		Type:      jwt.HeaderType,
 126  		KeyID:     string(g.pkID),
 127  	}
 128  	payload, err := jwt.EncodeJWS(&h, &claims, g.signer)
 129  	if err != nil {
 130  		return nil, err
 131  	}
 132  	v := url.Values{}
 133  	v.Set("grant_type", GrantType)
 134  	v.Set("audience", g.aud)
 135  	v.Set("requested_token_type", requestTokenType)
 136  	v.Set("subject_token", payload)
 137  	v.Set("subject_token_type", subjectTokenType)
 138  
 139  	req, err := http.NewRequestWithContext(ctx, "POST", g.tokenURL, strings.NewReader(v.Encode()))
 140  	if err != nil {
 141  		return nil, err
 142  	}
 143  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 144  	g.logger.DebugContext(ctx, "gdch token request", "request", internallog.HTTPRequest(req, []byte(v.Encode())))
 145  	resp, body, err := internal.DoRequest(g.client, req)
 146  	if err != nil {
 147  		return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
 148  	}
 149  	g.logger.DebugContext(ctx, "gdch token response", "response", internallog.HTTPResponse(resp, body))
 150  	if c := resp.StatusCode; c < http.StatusOK || c > http.StatusMultipleChoices {
 151  		return nil, &auth.Error{
 152  			Response: resp,
 153  			Body:     body,
 154  		}
 155  	}
 156  
 157  	var tokenRes struct {
 158  		AccessToken string `json:"access_token"`
 159  		TokenType   string `json:"token_type"`
 160  		ExpiresIn   int64  `json:"expires_in"` // relative seconds from now
 161  	}
 162  	if err := json.Unmarshal(body, &tokenRes); err != nil {
 163  		return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
 164  	}
 165  	token := &auth.Token{
 166  		Value: tokenRes.AccessToken,
 167  		Type:  tokenRes.TokenType,
 168  	}
 169  	raw := make(map[string]interface{})
 170  	json.Unmarshal(body, &raw) // no error checks for optional fields
 171  	token.Metadata = raw
 172  
 173  	if secs := tokenRes.ExpiresIn; secs > 0 {
 174  		token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
 175  	}
 176  	return token, nil
 177  }
 178  
 179  // addCertToTransport makes a best effort attempt at adding in the cert info to
 180  // the client. It tries to keep all configured transport settings if the
 181  // underlying transport is an http.Transport. Or else it overwrites the
 182  // transport with defaults adding in the certs.
 183  func addCertToTransport(hc *http.Client, certPool *x509.CertPool) {
 184  	trans, ok := hc.Transport.(*http.Transport)
 185  	if !ok {
 186  		trans = http.DefaultTransport.(*http.Transport).Clone()
 187  	}
 188  	trans.TLSClientConfig = &tls.Config{
 189  		RootCAs: certPool,
 190  	}
 191  }
 192