public.go raw
1 package azuredns
2
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"sort"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
	"github.com/go-acme/lego/v4/challenge"
	"github.com/go-acme/lego/v4/challenge/dns01"
	"github.com/go-acme/lego/v4/providers/dns/internal/ptr"
)
19
20 var _ challenge.ProviderTimeout = (*DNSProviderPublic)(nil)
21
22 // DNSProviderPublic implements the challenge.Provider interface for Azure Public Zone DNS.
23 type DNSProviderPublic struct {
24 config *Config
25 credentials azcore.TokenCredential
26 serviceDiscoveryZones map[string]ServiceDiscoveryZone
27 }
28
29 // NewDNSProviderPublic creates a DNSProviderPublic structure.
30 func NewDNSProviderPublic(config *Config, credentials azcore.TokenCredential) (*DNSProviderPublic, error) {
31 zones, err := discoverDNSZones(context.Background(), config, credentials)
32 if err != nil {
33 return nil, fmt.Errorf("discover DNS zones: %w", err)
34 }
35
36 return &DNSProviderPublic{
37 config: config,
38 credentials: credentials,
39 serviceDiscoveryZones: zones,
40 }, nil
41 }
42
43 // Timeout returns the timeout and interval to use when checking for DNS propagation.
44 // Adjusting here to cope with spikes in propagation times.
45 func (d *DNSProviderPublic) Timeout() (timeout, interval time.Duration) {
46 return d.config.PropagationTimeout, d.config.PollingInterval
47 }
48
49 // Present creates a TXT record to fulfill the dns-01 challenge.
50 func (d *DNSProviderPublic) Present(domain, _, keyAuth string) error {
51 ctx := context.Background()
52 info := dns01.GetChallengeInfo(domain, keyAuth)
53
54 zone, err := d.getHostedZone(info.EffectiveFQDN)
55 if err != nil {
56 return fmt.Errorf("azuredns: %w", err)
57 }
58
59 client, err := newPublicZoneClient(zone, d.credentials, d.config.Environment)
60 if err != nil {
61 return fmt.Errorf("azuredns: %w", err)
62 }
63
64 subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone.Name)
65 if err != nil {
66 return fmt.Errorf("azuredns: %w", err)
67 }
68
69 // Get existing record set
70 resp, err := client.Get(ctx, subDomain)
71 if err != nil {
72 var respErr *azcore.ResponseError
73 if !errors.As(err, &respErr) || respErr.StatusCode != http.StatusNotFound {
74 return fmt.Errorf("azuredns: %w", err)
75 }
76 }
77
78 uniqRecords := publicUniqueRecords(resp.RecordSet, info.Value)
79
80 var txtRecords []*armdns.TxtRecord
81 for txt := range uniqRecords {
82 txtRecords = append(txtRecords, &armdns.TxtRecord{Value: to.SliceOfPtrs(txt)})
83 }
84
85 rec := armdns.RecordSet{
86 Name: &subDomain,
87 Properties: &armdns.RecordSetProperties{
88 TTL: to.Ptr(int64(d.config.TTL)),
89 TxtRecords: txtRecords,
90 },
91 }
92
93 _, err = client.CreateOrUpdate(ctx, subDomain, rec)
94 if err != nil {
95 return fmt.Errorf("azuredns: %w", err)
96 }
97
98 return nil
99 }
100
101 // CleanUp removes the TXT record matching the specified parameters.
102 func (d *DNSProviderPublic) CleanUp(domain, _, keyAuth string) error {
103 ctx := context.Background()
104 info := dns01.GetChallengeInfo(domain, keyAuth)
105
106 zone, err := d.getHostedZone(info.EffectiveFQDN)
107 if err != nil {
108 return fmt.Errorf("azuredns: %w", err)
109 }
110
111 client, err := newPublicZoneClient(zone, d.credentials, d.config.Environment)
112 if err != nil {
113 return fmt.Errorf("azuredns: %w", err)
114 }
115
116 subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone.Name)
117 if err != nil {
118 return fmt.Errorf("azuredns: %w", err)
119 }
120
121 _, err = client.Delete(ctx, subDomain)
122 if err != nil {
123 return fmt.Errorf("azuredns: %w", err)
124 }
125
126 return nil
127 }
128
129 // Checks that azure has a zone for this domain name.
130 func (d *DNSProviderPublic) getHostedZone(fqdn string) (ServiceDiscoveryZone, error) {
131 authZone, err := getZoneName(d.config, fqdn)
132 if err != nil {
133 return ServiceDiscoveryZone{}, err
134 }
135
136 azureZone, exists := d.serviceDiscoveryZones[dns01.UnFqdn(authZone)]
137 if !exists {
138 return ServiceDiscoveryZone{}, fmt.Errorf("could not find zone (from discovery): %s", authZone)
139 }
140
141 return azureZone, nil
142 }
143
144 type publicZoneClient struct {
145 zone ServiceDiscoveryZone
146 recordClient *armdns.RecordSetsClient
147 }
148
149 // newPublicZoneClient creates publicZoneClient structure with initialized Azure client.
150 func newPublicZoneClient(zone ServiceDiscoveryZone, credential azcore.TokenCredential, environment cloud.Configuration) (*publicZoneClient, error) {
151 options := &arm.ClientOptions{
152 ClientOptions: azcore.ClientOptions{
153 Cloud: environment,
154 },
155 }
156
157 recordClient, err := armdns.NewRecordSetsClient(zone.SubscriptionID, credential, options)
158 if err != nil {
159 return nil, err
160 }
161
162 return &publicZoneClient{
163 zone: zone,
164 recordClient: recordClient,
165 }, nil
166 }
167
168 func (c publicZoneClient) Get(ctx context.Context, subDomain string) (armdns.RecordSetsClientGetResponse, error) {
169 return c.recordClient.Get(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, nil)
170 }
171
172 func (c publicZoneClient) CreateOrUpdate(ctx context.Context, subDomain string, rec armdns.RecordSet) (armdns.RecordSetsClientCreateOrUpdateResponse, error) {
173 return c.recordClient.CreateOrUpdate(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, rec, nil)
174 }
175
176 func (c publicZoneClient) Delete(ctx context.Context, subDomain string) (armdns.RecordSetsClientDeleteResponse, error) {
177 return c.recordClient.Delete(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, nil)
178 }
179
180 func publicUniqueRecords(recordSet armdns.RecordSet, value string) map[string]struct{} {
181 uniqRecords := map[string]struct{}{value: {}}
182
183 if recordSet.Properties != nil && recordSet.Properties.TxtRecords != nil {
184 for _, txtRecord := range recordSet.Properties.TxtRecords {
185 // Assume Value doesn't contain multiple strings
186 if len(txtRecord.Value) > 0 {
187 uniqRecords[ptr.Deref(txtRecord.Value[0])] = struct{}{}
188 }
189 }
190 }
191
192 return uniqRecords
193 }
194