credentials.go raw
1 package azuredns
2
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
	"github.com/go-acme/lego/v4/challenge/dns01"
)
15
// Recognized values of Config.AuthMethod. getCredentials lowercases the
// configured value before comparing it against these constants; any value
// not listed here selects azidentity.NewDefaultAzureCredential.
const (
	// authMethodCLI selects azidentity.NewAzureCLICredential.
	authMethodCLI = "cli"
	// authMethodEnv selects azidentity.NewClientSecretCredential when ClientID,
	// ClientSecret and TenantID are all set, else azidentity.NewEnvironmentCredential.
	authMethodEnv = "env"
	// authMethodMSI selects azidentity.NewManagedIdentityCredential, wrapped in
	// timeoutTokenCredential.
	authMethodMSI = "msi"
	// authMethodOIDC selects azidentity.NewClientAssertionCredential.
	authMethodOIDC = "oidc"
	// authMethodPipeline selects azidentity.NewAzurePipelinesCredential.
	authMethodPipeline = "pipeline"
	// authMethodWLI selects azidentity.NewWorkloadIdentityCredential.
	authMethodWLI = "wli"
)
24
25 //nolint:gocyclo // The complexity is related to the number of possible configurations.
26 func getCredentials(config *Config) (azcore.TokenCredential, error) {
27 clientOptions := azcore.ClientOptions{Cloud: config.Environment}
28
29 switch strings.ToLower(config.AuthMethod) {
30 case authMethodEnv:
31 if config.ClientID != "" && config.ClientSecret != "" && config.TenantID != "" {
32 return azidentity.NewClientSecretCredential(config.TenantID, config.ClientID, config.ClientSecret,
33 &azidentity.ClientSecretCredentialOptions{ClientOptions: clientOptions})
34 }
35
36 return azidentity.NewEnvironmentCredential(&azidentity.EnvironmentCredentialOptions{ClientOptions: clientOptions})
37
38 case authMethodWLI:
39 return azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ClientOptions: clientOptions})
40
41 case authMethodMSI:
42 cred, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ClientOptions: clientOptions})
43 if err != nil {
44 return nil, err
45 }
46
47 return &timeoutTokenCredential{cred: cred, timeout: config.AuthMSITimeout}, nil
48
49 case authMethodCLI:
50 var credOptions *azidentity.AzureCLICredentialOptions
51 if config.TenantID != "" {
52 credOptions = &azidentity.AzureCLICredentialOptions{TenantID: config.TenantID}
53 }
54
55 return azidentity.NewAzureCLICredential(credOptions)
56
57 case authMethodOIDC:
58 err := checkOIDCConfig(config)
59 if err != nil {
60 return nil, err
61 }
62
63 return azidentity.NewClientAssertionCredential(config.TenantID, config.ClientID, getOIDCAssertion(config), &azidentity.ClientAssertionCredentialOptions{ClientOptions: clientOptions})
64
65 case authMethodPipeline:
66 err := checkPipelineConfig(config)
67 if err != nil {
68 return nil, err
69 }
70
71 // Uses the env var `SYSTEM_OIDCREQUESTURI`,
72 // but the constant is not exported,
73 // and there is no way to set it programmatically.
74 // https://github.com/Azure/azure-sdk-for-go/blob/aae2fb75ffccafc669db72bebc3c1a66332f48d7/sdk/azidentity/azure_pipelines_credential.go#L22
75 // https://github.com/Azure/azure-sdk-for-go/blob/aae2fb75ffccafc669db72bebc3c1a66332f48d7/sdk/azidentity/azure_pipelines_credential.go#L79
76
77 return azidentity.NewAzurePipelinesCredential(config.TenantID, config.ClientID, config.ServiceConnectionID, config.SystemAccessToken, &azidentity.AzurePipelinesCredentialOptions{ClientOptions: clientOptions})
78
79 default:
80 return azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{ClientOptions: clientOptions})
81 }
82 }
83
84 // timeoutTokenCredential wraps a TokenCredential to add a timeout.
85 type timeoutTokenCredential struct {
86 cred azcore.TokenCredential
87 timeout time.Duration
88 }
89
90 // GetToken implements the azcore.TokenCredential interface.
91 func (w *timeoutTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
92 if w.timeout <= 0 {
93 return w.cred.GetToken(ctx, opts)
94 }
95
96 ctxTimeout, cancel := context.WithTimeout(ctx, w.timeout)
97 defer cancel()
98
99 tk, err := w.cred.GetToken(ctxTimeout, opts)
100 if ce := ctxTimeout.Err(); errors.Is(ce, context.DeadlineExceeded) {
101 return tk, azidentity.NewCredentialUnavailableError("managed identity timed out")
102 }
103
104 w.timeout = 0
105
106 return tk, err
107 }
108
// getZoneName returns the DNS zone that holds records for fqdn.
// A non-empty config.ZoneName is returned as-is. Otherwise the zone is looked
// up with dns01.FindZoneByFqdn. A lookup error is wrapped with the fqdn, and
// an empty lookup result is rejected.
func getZoneName(config *Config, fqdn string) (string, error) {
	if configured := config.ZoneName; configured != "" {
		return configured, nil
	}

	zone, err := dns01.FindZoneByFqdn(fqdn)

	switch {
	case err != nil:
		return "", fmt.Errorf("could not find zone for %s: %w", fqdn, err)
	case zone == "":
		return "", errors.New("empty zone name")
	default:
		return zone, nil
	}
}
125
// checkPipelineConfig validates the settings needed by the "pipeline" auth
// method. It returns an error naming the first missing field,
// ServiceConnectionID before SystemAccessToken, or nil when both are set.
func checkPipelineConfig(config *Config) error {
	switch {
	case config.ServiceConnectionID == "":
		return errors.New("azuredns: ServiceConnectionID is missing")
	case config.SystemAccessToken == "":
		return errors.New("azuredns: SystemAccessToken is missing")
	default:
		return nil
	}
}
137