oidc.go raw

   1  package azuredns
   2  
   3  import (
   4  	"context"
   5  	"encoding/json"
   6  	"errors"
   7  	"fmt"
   8  	"io"
   9  	"net/http"
  10  	"net/url"
  11  	"os"
  12  	"strings"
  13  )
  14  
  15  func checkOIDCConfig(config *Config) error {
  16  	if config.TenantID == "" {
  17  		return errors.New("azuredns: TenantID is missing")
  18  	}
  19  
  20  	if config.ClientID == "" {
  21  		return errors.New("azuredns: ClientID is missing")
  22  	}
  23  
  24  	if config.OIDCToken == "" && config.OIDCTokenFilePath == "" && (config.OIDCRequestURL == "" || config.OIDCRequestToken == "") {
  25  		return errors.New("azuredns: OIDCToken, OIDCTokenFilePath or OIDCRequestURL and OIDCRequestToken must be set")
  26  	}
  27  
  28  	return nil
  29  }
  30  
  31  func getOIDCAssertion(config *Config) func(ctx context.Context) (string, error) {
  32  	return func(ctx context.Context) (string, error) {
  33  		var token string
  34  		if config.OIDCToken != "" {
  35  			token = strings.TrimSpace(config.OIDCToken)
  36  		}
  37  
  38  		if config.OIDCTokenFilePath != "" {
  39  			fileTokenRaw, err := os.ReadFile(config.OIDCTokenFilePath)
  40  			if err != nil {
  41  				return "", fmt.Errorf("azuredns: error retrieving token file with path %s: %w", config.OIDCTokenFilePath, err)
  42  			}
  43  
  44  			fileToken := strings.TrimSpace(string(fileTokenRaw))
  45  			if config.OIDCToken != fileToken {
  46  				return "", fmt.Errorf("azuredns: token file with path %s does not match token from environment variable", config.OIDCTokenFilePath)
  47  			}
  48  
  49  			token = fileToken
  50  		}
  51  
  52  		if token == "" && config.OIDCRequestURL != "" && config.OIDCRequestToken != "" {
  53  			return getOIDCToken(config)
  54  		}
  55  
  56  		return token, nil
  57  	}
  58  }
  59  
  60  func getOIDCToken(config *Config) (string, error) {
  61  	req, err := http.NewRequest(http.MethodGet, config.OIDCRequestURL, http.NoBody)
  62  	if err != nil {
  63  		return "", fmt.Errorf("azuredns: failed to build OIDC request: %w", err)
  64  	}
  65  
  66  	query, err := url.ParseQuery(req.URL.RawQuery)
  67  	if err != nil {
  68  		return "", errors.New("azuredns: cannot parse OIDC request URL query")
  69  	}
  70  
  71  	if query.Get("audience") == "" {
  72  		query.Set("audience", "api://AzureADTokenExchange")
  73  		req.URL.RawQuery = query.Encode()
  74  	}
  75  
  76  	req.Header.Set("Accept", "application/json")
  77  	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", config.OIDCRequestToken))
  78  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  79  
  80  	resp, err := config.HTTPClient.Do(req)
  81  	if err != nil {
  82  		return "", fmt.Errorf("azuredns: cannot request OIDC token: %w", err)
  83  	}
  84  
  85  	defer func() { _ = resp.Body.Close() }()
  86  
  87  	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
  88  	if err != nil {
  89  		return "", fmt.Errorf("azuredns: cannot parse OIDC token response: %w", err)
  90  	}
  91  
  92  	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusNoContent {
  93  		return "", fmt.Errorf("azuredns: OIDC token request received HTTP status %d with response: %s", resp.StatusCode, body)
  94  	}
  95  
  96  	var returnedToken struct {
  97  		Count int    `json:"count"`
  98  		Value string `json:"value"`
  99  	}
 100  	if err := json.Unmarshal(body, &returnedToken); err != nil {
 101  		return "", fmt.Errorf("azuredns: cannot unmarshal OIDC token response: %w", err)
 102  	}
 103  
 104  	return returnedToken.Value, nil
 105  }
 106