oidc.go raw

   1  package providers
   2  
   3  import (
   4  	"encoding/json"
   5  	"errors"
   6  	"fmt"
   7  	"io/ioutil"
   8  	"net/http"
   9  	"os"
  10  	"strconv"
  11  	"strings"
  12  	"time"
  13  
  14  	httputil "github.com/aliyun/credentials-go/credentials/internal/http"
  15  	"github.com/aliyun/credentials-go/credentials/internal/utils"
  16  )
  17  
// OIDCCredentialsProvider obtains temporary session credentials by calling the
// STS AssumeRoleWithOIDC action with a token read from a local file, and
// caches the result together with its expiration time.
type OIDCCredentialsProvider struct {
	oidcProviderARN   string // sent as the OIDCProviderArn form field
	oidcTokenFilePath string // file read on every fetch; its content is sent as OIDCToken
	roleArn           string // sent as the RoleArn form field
	roleSessionName   string // sent as the RoleSessionName form field
	durationSeconds   int    // sent as the DurationSeconds form field
	policy            string // sent as the Policy form field when non-empty
	// for sts endpoint
	stsRegionId string // region Build uses to derive stsEndpoint
	enableVpc   bool   // selects the "sts-vpc" host prefix when Build derives a regional endpoint
	stsEndpoint string // host the STS request is sent to

	lastUpdateTimestamp int64               // Unix seconds recorded when the cache was last refreshed
	expirationTimestamp int64               // Unix seconds at which the cached credentials expire; 0 forces a refresh
	sessionCredentials  *sessionCredentials // most recently fetched credentials; nil until the first fetch
	// for http options
	httpOptions *HttpOptions // optional timeouts and proxy for the STS request; may be nil
}
  36  
// OIDCCredentialsProviderBuilder accumulates settings on an
// OIDCCredentialsProvider through its With* methods; Build validates them and
// hands the provider back.
type OIDCCredentialsProviderBuilder struct {
	provider *OIDCCredentialsProvider // provider being configured; returned by Build
}
  40  
  41  func NewOIDCCredentialsProviderBuilder() *OIDCCredentialsProviderBuilder {
  42  	return &OIDCCredentialsProviderBuilder{
  43  		provider: &OIDCCredentialsProvider{},
  44  	}
  45  }
  46  
  47  func (b *OIDCCredentialsProviderBuilder) WithOIDCProviderARN(oidcProviderArn string) *OIDCCredentialsProviderBuilder {
  48  	b.provider.oidcProviderARN = oidcProviderArn
  49  	return b
  50  }
  51  
  52  func (b *OIDCCredentialsProviderBuilder) WithOIDCTokenFilePath(oidcTokenFilePath string) *OIDCCredentialsProviderBuilder {
  53  	b.provider.oidcTokenFilePath = oidcTokenFilePath
  54  	return b
  55  }
  56  
  57  func (b *OIDCCredentialsProviderBuilder) WithRoleArn(roleArn string) *OIDCCredentialsProviderBuilder {
  58  	b.provider.roleArn = roleArn
  59  	return b
  60  }
  61  
  62  func (b *OIDCCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *OIDCCredentialsProviderBuilder {
  63  	b.provider.roleSessionName = roleSessionName
  64  	return b
  65  }
  66  
  67  func (b *OIDCCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *OIDCCredentialsProviderBuilder {
  68  	b.provider.durationSeconds = durationSeconds
  69  	return b
  70  }
  71  
  72  func (b *OIDCCredentialsProviderBuilder) WithStsRegionId(regionId string) *OIDCCredentialsProviderBuilder {
  73  	b.provider.stsRegionId = regionId
  74  	return b
  75  }
  76  
  77  func (b *OIDCCredentialsProviderBuilder) WithEnableVpc(enableVpc bool) *OIDCCredentialsProviderBuilder {
  78  	b.provider.enableVpc = enableVpc
  79  	return b
  80  }
  81  
  82  func (b *OIDCCredentialsProviderBuilder) WithPolicy(policy string) *OIDCCredentialsProviderBuilder {
  83  	b.provider.policy = policy
  84  	return b
  85  }
  86  
  87  func (b *OIDCCredentialsProviderBuilder) WithSTSEndpoint(stsEndpoint string) *OIDCCredentialsProviderBuilder {
  88  	b.provider.stsEndpoint = stsEndpoint
  89  	return b
  90  }
  91  
  92  func (b *OIDCCredentialsProviderBuilder) WithHttpOptions(httpOptions *HttpOptions) *OIDCCredentialsProviderBuilder {
  93  	b.provider.httpOptions = httpOptions
  94  	return b
  95  }
  96  
  97  func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvider, err error) {
  98  	if b.provider.roleSessionName == "" {
  99  		b.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10)
 100  	}
 101  
 102  	if b.provider.oidcTokenFilePath == "" {
 103  		b.provider.oidcTokenFilePath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE")
 104  	}
 105  
 106  	if b.provider.oidcTokenFilePath == "" {
 107  		err = errors.New("the OIDCTokenFilePath is empty")
 108  		return
 109  	}
 110  
 111  	if b.provider.oidcProviderARN == "" {
 112  		b.provider.oidcProviderARN = os.Getenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN")
 113  	}
 114  
 115  	if b.provider.oidcProviderARN == "" {
 116  		err = errors.New("the OIDCProviderARN is empty")
 117  		return
 118  	}
 119  
 120  	if b.provider.roleArn == "" {
 121  		b.provider.roleArn = os.Getenv("ALIBABA_CLOUD_ROLE_ARN")
 122  	}
 123  
 124  	if b.provider.roleArn == "" {
 125  		err = errors.New("the RoleArn is empty")
 126  		return
 127  	}
 128  
 129  	if b.provider.durationSeconds == 0 {
 130  		b.provider.durationSeconds = 3600
 131  	}
 132  
 133  	if b.provider.durationSeconds < 900 {
 134  		err = errors.New("the Assume Role session duration should be in the range of 15min - max duration seconds")
 135  	}
 136  
 137  	if b.provider.stsEndpoint == "" {
 138  		if !b.provider.enableVpc {
 139  			b.provider.enableVpc = strings.ToLower(os.Getenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED")) == "true"
 140  		}
 141  		prefix := "sts"
 142  		if b.provider.enableVpc {
 143  			prefix = "sts-vpc"
 144  		}
 145  		if b.provider.stsRegionId != "" {
 146  			b.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, b.provider.stsRegionId)
 147  		} else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" {
 148  			b.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, region)
 149  		} else {
 150  			b.provider.stsEndpoint = "sts.aliyuncs.com"
 151  		}
 152  	}
 153  
 154  	provider = b.provider
 155  	return
 156  }
 157  
 158  func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCredentials, err error) {
 159  	req := &httputil.Request{
 160  		Method:   "POST",
 161  		Protocol: "https",
 162  		Host:     provider.stsEndpoint,
 163  		Headers:  map[string]string{},
 164  	}
 165  
 166  	connectTimeout := 5 * time.Second
 167  	readTimeout := 10 * time.Second
 168  
 169  	if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
 170  		connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
 171  	}
 172  	if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
 173  		readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
 174  	}
 175  	if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
 176  		req.Proxy = provider.httpOptions.Proxy
 177  	}
 178  	req.ConnectTimeout = connectTimeout
 179  	req.ReadTimeout = readTimeout
 180  
 181  	queries := make(map[string]string)
 182  	queries["Version"] = "2015-04-01"
 183  	queries["Action"] = "AssumeRoleWithOIDC"
 184  	queries["Format"] = "JSON"
 185  	queries["Timestamp"] = utils.GetTimeInFormatISO8601()
 186  	req.Queries = queries
 187  
 188  	bodyForm := make(map[string]string)
 189  	bodyForm["RoleArn"] = provider.roleArn
 190  	bodyForm["OIDCProviderArn"] = provider.oidcProviderARN
 191  	token, err := ioutil.ReadFile(provider.oidcTokenFilePath)
 192  	if err != nil {
 193  		return
 194  	}
 195  
 196  	bodyForm["OIDCToken"] = string(token)
 197  	if provider.policy != "" {
 198  		bodyForm["Policy"] = provider.policy
 199  	}
 200  
 201  	bodyForm["RoleSessionName"] = provider.roleSessionName
 202  	bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds)
 203  	req.Form = bodyForm
 204  
 205  	// set headers
 206  	req.Headers["Accept-Encoding"] = "identity"
 207  	res, err := httpDo(req)
 208  	if err != nil {
 209  		return
 210  	}
 211  
 212  	if res.StatusCode != http.StatusOK {
 213  		message := "get session token failed: "
 214  		err = errors.New(message + string(res.Body))
 215  		return
 216  	}
 217  	var data assumeRoleResponse
 218  	err = json.Unmarshal(res.Body, &data)
 219  	if err != nil {
 220  		err = fmt.Errorf("get oidc sts token err, json.Unmarshal fail: %s", err.Error())
 221  		return
 222  	}
 223  	if data.Credentials == nil {
 224  		err = fmt.Errorf("get oidc sts token err, fail to get credentials")
 225  		return
 226  	}
 227  
 228  	if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil {
 229  		err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
 230  		return
 231  	}
 232  
 233  	session = &sessionCredentials{
 234  		AccessKeyId:     *data.Credentials.AccessKeyId,
 235  		AccessKeySecret: *data.Credentials.AccessKeySecret,
 236  		SecurityToken:   *data.Credentials.SecurityToken,
 237  		Expiration:      *data.Credentials.Expiration,
 238  	}
 239  	return
 240  }
 241  
 242  func (provider *OIDCCredentialsProvider) needUpdateCredential() (result bool) {
 243  	if provider.expirationTimestamp == 0 {
 244  		return true
 245  	}
 246  
 247  	return provider.expirationTimestamp-time.Now().Unix() <= 180
 248  }
 249  
 250  func (provider *OIDCCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
 251  	if provider.sessionCredentials == nil || provider.needUpdateCredential() {
 252  		sessionCredentials, err1 := provider.getCredentials()
 253  		if err1 != nil {
 254  			return nil, err1
 255  		}
 256  
 257  		provider.sessionCredentials = sessionCredentials
 258  		expirationTime, err2 := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration)
 259  		if err2 != nil {
 260  			return nil, err2
 261  		}
 262  
 263  		provider.lastUpdateTimestamp = time.Now().Unix()
 264  		provider.expirationTimestamp = expirationTime.Unix()
 265  	}
 266  
 267  	cc = &Credentials{
 268  		AccessKeyId:     provider.sessionCredentials.AccessKeyId,
 269  		AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
 270  		SecurityToken:   provider.sessionCredentials.SecurityToken,
 271  		ProviderName:    provider.GetProviderName(),
 272  	}
 273  	return
 274  }
 275  
 276  func (provider *OIDCCredentialsProvider) GetProviderName() string {
 277  	return "oidc_role_arn"
 278  }
 279