sso_cached_token.go raw

   1  package ssocreds
   2  
   3  import (
   4  	"crypto/sha1"
   5  	"encoding/hex"
   6  	"encoding/json"
   7  	"fmt"
   8  	"io/ioutil"
   9  	"os"
  10  	"path/filepath"
  11  	"strconv"
  12  	"strings"
  13  	"time"
  14  
  15  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
  16  	"github.com/aws/aws-sdk-go-v2/internal/shareddefaults"
  17  )
  18  
// osUserHomeDur is the function used to look up the user's home directory
// when deriving the cached token filepath. It is bound to
// shareddefaults.UserHomeDir through a package variable, presumably so tests
// can substitute a stub (confirm against the package's tests).
var osUserHomeDur = shareddefaults.UserHomeDir
  20  
  21  // StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or
  22  // error if unable get derive the path. Key that will be used to compute a SHA1
  23  // value that is hex encoded.
  24  //
  25  // Derives the filepath using the Key as:
  26  //
  27  //	~/.aws/sso/cache/<sha1-hex-encoded-key>.json
  28  func StandardCachedTokenFilepath(key string) (string, error) {
  29  	homeDir := osUserHomeDur()
  30  	if len(homeDir) == 0 {
  31  		return "", fmt.Errorf("unable to get USER's home directory for cached token")
  32  	}
  33  	hash := sha1.New()
  34  	if _, err := hash.Write([]byte(key)); err != nil {
  35  		return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %w", err)
  36  	}
  37  
  38  	cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json"
  39  
  40  	return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil
  41  }
  42  
// tokenKnownFields holds the members of a cached SSO token file that this
// package reads and writes by name. The JSON keys in the struct tags match the
// keys used by token's custom MarshalJSON and UnmarshalJSON methods.
type tokenKnownFields struct {
	// AccessToken is stored under the "accessToken" key. loadCachedToken
	// rejects a cached token in which it is empty.
	AccessToken string   `json:"accessToken,omitempty"`
	// ExpiresAt is stored under the "expiresAt" key as an RFC3339 timestamp
	// string. loadCachedToken rejects a cached token in which it is missing
	// or the zero time.
	ExpiresAt   *rfc3339 `json:"expiresAt,omitempty"`

	// RefreshToken, ClientID and ClientSecret are stored under
	// "refreshToken", "clientId" and "clientSecret". Each is left out of the
	// encoded JSON when empty.
	RefreshToken string `json:"refreshToken,omitempty"`
	ClientID     string `json:"clientId,omitempty"`
	ClientSecret string `json:"clientSecret,omitempty"`
}
  51  
// token is the in-memory form of a cached SSO token file. Members known by
// name live in the embedded tokenKnownFields; every other top-level JSON
// member is kept in UnknownFields so it is written back out when the token is
// marshaled.
type token struct {
	tokenKnownFields
	// UnknownFields maps a JSON member name to its decoded value. The `json:"-"`
	// tag excludes it from default struct encoding; the custom MarshalJSON and
	// UnmarshalJSON methods on token handle it instead.
	UnknownFields map[string]interface{} `json:"-"`
}
  56  
  57  func (t token) MarshalJSON() ([]byte, error) {
  58  	fields := map[string]interface{}{}
  59  
  60  	setTokenFieldString(fields, "accessToken", t.AccessToken)
  61  	setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt)
  62  
  63  	setTokenFieldString(fields, "refreshToken", t.RefreshToken)
  64  	setTokenFieldString(fields, "clientId", t.ClientID)
  65  	setTokenFieldString(fields, "clientSecret", t.ClientSecret)
  66  
  67  	for k, v := range t.UnknownFields {
  68  		if _, ok := fields[k]; ok {
  69  			return nil, fmt.Errorf("unknown token field %v, duplicates known field", k)
  70  		}
  71  		fields[k] = v
  72  	}
  73  
  74  	return json.Marshal(fields)
  75  }
  76  
  77  func setTokenFieldString(fields map[string]interface{}, key, value string) {
  78  	if value == "" {
  79  		return
  80  	}
  81  	fields[key] = value
  82  }
  83  func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) {
  84  	if value == nil {
  85  		return
  86  	}
  87  	fields[key] = value
  88  }
  89  
  90  func (t *token) UnmarshalJSON(b []byte) error {
  91  	var fields map[string]interface{}
  92  	if err := json.Unmarshal(b, &fields); err != nil {
  93  		return nil
  94  	}
  95  
  96  	t.UnknownFields = map[string]interface{}{}
  97  
  98  	for k, v := range fields {
  99  		var err error
 100  		switch k {
 101  		case "accessToken":
 102  			err = getTokenFieldString(v, &t.AccessToken)
 103  		case "expiresAt":
 104  			err = getTokenFieldRFC3339(v, &t.ExpiresAt)
 105  		case "refreshToken":
 106  			err = getTokenFieldString(v, &t.RefreshToken)
 107  		case "clientId":
 108  			err = getTokenFieldString(v, &t.ClientID)
 109  		case "clientSecret":
 110  			err = getTokenFieldString(v, &t.ClientSecret)
 111  		default:
 112  			t.UnknownFields[k] = v
 113  		}
 114  
 115  		if err != nil {
 116  			return fmt.Errorf("field %q, %w", k, err)
 117  		}
 118  	}
 119  
 120  	return nil
 121  }
 122  
 123  func getTokenFieldString(v interface{}, value *string) error {
 124  	var ok bool
 125  	*value, ok = v.(string)
 126  	if !ok {
 127  		return fmt.Errorf("expect value to be string, got %T", v)
 128  	}
 129  	return nil
 130  }
 131  
 132  func getTokenFieldRFC3339(v interface{}, value **rfc3339) error {
 133  	var stringValue string
 134  	if err := getTokenFieldString(v, &stringValue); err != nil {
 135  		return err
 136  	}
 137  
 138  	timeValue, err := parseRFC3339(stringValue)
 139  	if err != nil {
 140  		return err
 141  	}
 142  
 143  	*value = &timeValue
 144  	return nil
 145  }
 146  
 147  func loadCachedToken(filename string) (token, error) {
 148  	fileBytes, err := ioutil.ReadFile(filename)
 149  	if err != nil {
 150  		return token{}, fmt.Errorf("failed to read cached SSO token file, %w", err)
 151  	}
 152  
 153  	var t token
 154  	if err := json.Unmarshal(fileBytes, &t); err != nil {
 155  		return token{}, fmt.Errorf("failed to parse cached SSO token file, %w", err)
 156  	}
 157  
 158  	if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() {
 159  		return token{}, fmt.Errorf(
 160  			"cached SSO token must contain accessToken and expiresAt fields")
 161  	}
 162  
 163  	return t, nil
 164  }
 165  
 166  func storeCachedToken(filename string, t token, fileMode os.FileMode) (err error) {
 167  	tmpFilename := filename + ".tmp-" + strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
 168  	if err := writeCacheFile(tmpFilename, fileMode, t); err != nil {
 169  		return err
 170  	}
 171  
 172  	if err := os.Rename(tmpFilename, filename); err != nil {
 173  		return fmt.Errorf("failed to replace old cached SSO token file, %w", err)
 174  	}
 175  
 176  	return nil
 177  }
 178  
 179  func writeCacheFile(filename string, fileMode os.FileMode, t token) (err error) {
 180  	var f *os.File
 181  	f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode)
 182  	if err != nil {
 183  		return fmt.Errorf("failed to create cached SSO token file %w", err)
 184  	}
 185  
 186  	defer func() {
 187  		closeErr := f.Close()
 188  		if err == nil && closeErr != nil {
 189  			err = fmt.Errorf("failed to close cached SSO token file, %w", closeErr)
 190  		}
 191  	}()
 192  
 193  	encoder := json.NewEncoder(f)
 194  
 195  	if err = encoder.Encode(t); err != nil {
 196  		return fmt.Errorf("failed to serialize cached SSO token, %w", err)
 197  	}
 198  
 199  	return nil
 200  }
 201  
 202  type rfc3339 time.Time
 203  
 204  func parseRFC3339(v string) (rfc3339, error) {
 205  	parsed, err := time.Parse(time.RFC3339, v)
 206  	if err != nil {
 207  		return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %w", err)
 208  	}
 209  
 210  	return rfc3339(parsed), nil
 211  }
 212  
 213  func (r *rfc3339) UnmarshalJSON(bytes []byte) (err error) {
 214  	var value string
 215  
 216  	// Use JSON unmarshal to unescape the quoted value making use of JSON's
 217  	// unquoting rules.
 218  	if err = json.Unmarshal(bytes, &value); err != nil {
 219  		return err
 220  	}
 221  
 222  	*r, err = parseRFC3339(value)
 223  
 224  	return nil
 225  }
 226  
 227  func (r *rfc3339) MarshalJSON() ([]byte, error) {
 228  	value := time.Time(*r).UTC().Format(time.RFC3339)
 229  
 230  	// Use JSON unmarshal to unescape the quoted value making use of JSON's
 231  	// quoting rules.
 232  	return json.Marshal(value)
 233  }
 234