public.go raw

   1  package azuredns
   2  
   3  import (
   4  	"context"
   5  	"errors"
   6  	"fmt"
   7  	"net/http"
   8  	"time"
   9  
  10  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
  11  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
  12  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
  13  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
  14  	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
  15  	"github.com/go-acme/lego/v4/challenge"
  16  	"github.com/go-acme/lego/v4/challenge/dns01"
  17  	"github.com/go-acme/lego/v4/providers/dns/internal/ptr"
  18  )
  19  
// Compile-time check that DNSProviderPublic satisfies challenge.ProviderTimeout.
var _ challenge.ProviderTimeout = (*DNSProviderPublic)(nil)

// DNSProviderPublic implements the challenge.Provider interface for Azure Public Zone DNS.
// It also implements challenge.ProviderTimeout, reporting the propagation timeout
// and polling interval stored in its Config.
type DNSProviderPublic struct {
	// config holds the provider settings read by this file (TTL, PropagationTimeout,
	// PollingInterval, Environment).
	config                *Config
	// credentials authenticates every Azure record-set client this provider creates.
	credentials           azcore.TokenCredential
	// serviceDiscoveryZones is filled once by discoverDNSZones in the constructor;
	// getHostedZone looks zones up by the authoritative zone name with the trailing
	// dot removed (dns01.UnFqdn).
	serviceDiscoveryZones map[string]ServiceDiscoveryZone
}
  28  
  29  // NewDNSProviderPublic creates a DNSProviderPublic structure.
  30  func NewDNSProviderPublic(config *Config, credentials azcore.TokenCredential) (*DNSProviderPublic, error) {
  31  	zones, err := discoverDNSZones(context.Background(), config, credentials)
  32  	if err != nil {
  33  		return nil, fmt.Errorf("discover DNS zones: %w", err)
  34  	}
  35  
  36  	return &DNSProviderPublic{
  37  		config:                config,
  38  		credentials:           credentials,
  39  		serviceDiscoveryZones: zones,
  40  	}, nil
  41  }
  42  
  43  // Timeout returns the timeout and interval to use when checking for DNS propagation.
  44  // Adjusting here to cope with spikes in propagation times.
  45  func (d *DNSProviderPublic) Timeout() (timeout, interval time.Duration) {
  46  	return d.config.PropagationTimeout, d.config.PollingInterval
  47  }
  48  
  49  // Present creates a TXT record to fulfill the dns-01 challenge.
  50  func (d *DNSProviderPublic) Present(domain, _, keyAuth string) error {
  51  	ctx := context.Background()
  52  	info := dns01.GetChallengeInfo(domain, keyAuth)
  53  
  54  	zone, err := d.getHostedZone(info.EffectiveFQDN)
  55  	if err != nil {
  56  		return fmt.Errorf("azuredns: %w", err)
  57  	}
  58  
  59  	client, err := newPublicZoneClient(zone, d.credentials, d.config.Environment)
  60  	if err != nil {
  61  		return fmt.Errorf("azuredns: %w", err)
  62  	}
  63  
  64  	subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone.Name)
  65  	if err != nil {
  66  		return fmt.Errorf("azuredns: %w", err)
  67  	}
  68  
  69  	// Get existing record set
  70  	resp, err := client.Get(ctx, subDomain)
  71  	if err != nil {
  72  		var respErr *azcore.ResponseError
  73  		if !errors.As(err, &respErr) || respErr.StatusCode != http.StatusNotFound {
  74  			return fmt.Errorf("azuredns: %w", err)
  75  		}
  76  	}
  77  
  78  	uniqRecords := publicUniqueRecords(resp.RecordSet, info.Value)
  79  
  80  	var txtRecords []*armdns.TxtRecord
  81  	for txt := range uniqRecords {
  82  		txtRecords = append(txtRecords, &armdns.TxtRecord{Value: to.SliceOfPtrs(txt)})
  83  	}
  84  
  85  	rec := armdns.RecordSet{
  86  		Name: &subDomain,
  87  		Properties: &armdns.RecordSetProperties{
  88  			TTL:        to.Ptr(int64(d.config.TTL)),
  89  			TxtRecords: txtRecords,
  90  		},
  91  	}
  92  
  93  	_, err = client.CreateOrUpdate(ctx, subDomain, rec)
  94  	if err != nil {
  95  		return fmt.Errorf("azuredns: %w", err)
  96  	}
  97  
  98  	return nil
  99  }
 100  
 101  // CleanUp removes the TXT record matching the specified parameters.
 102  func (d *DNSProviderPublic) CleanUp(domain, _, keyAuth string) error {
 103  	ctx := context.Background()
 104  	info := dns01.GetChallengeInfo(domain, keyAuth)
 105  
 106  	zone, err := d.getHostedZone(info.EffectiveFQDN)
 107  	if err != nil {
 108  		return fmt.Errorf("azuredns: %w", err)
 109  	}
 110  
 111  	client, err := newPublicZoneClient(zone, d.credentials, d.config.Environment)
 112  	if err != nil {
 113  		return fmt.Errorf("azuredns: %w", err)
 114  	}
 115  
 116  	subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone.Name)
 117  	if err != nil {
 118  		return fmt.Errorf("azuredns: %w", err)
 119  	}
 120  
 121  	_, err = client.Delete(ctx, subDomain)
 122  	if err != nil {
 123  		return fmt.Errorf("azuredns: %w", err)
 124  	}
 125  
 126  	return nil
 127  }
 128  
 129  // Checks that azure has a zone for this domain name.
 130  func (d *DNSProviderPublic) getHostedZone(fqdn string) (ServiceDiscoveryZone, error) {
 131  	authZone, err := getZoneName(d.config, fqdn)
 132  	if err != nil {
 133  		return ServiceDiscoveryZone{}, err
 134  	}
 135  
 136  	azureZone, exists := d.serviceDiscoveryZones[dns01.UnFqdn(authZone)]
 137  	if !exists {
 138  		return ServiceDiscoveryZone{}, fmt.Errorf("could not find zone (from discovery): %s", authZone)
 139  	}
 140  
 141  	return azureZone, nil
 142  }
 143  
 144  type publicZoneClient struct {
 145  	zone         ServiceDiscoveryZone
 146  	recordClient *armdns.RecordSetsClient
 147  }
 148  
 149  // newPublicZoneClient creates publicZoneClient structure with initialized Azure client.
 150  func newPublicZoneClient(zone ServiceDiscoveryZone, credential azcore.TokenCredential, environment cloud.Configuration) (*publicZoneClient, error) {
 151  	options := &arm.ClientOptions{
 152  		ClientOptions: azcore.ClientOptions{
 153  			Cloud: environment,
 154  		},
 155  	}
 156  
 157  	recordClient, err := armdns.NewRecordSetsClient(zone.SubscriptionID, credential, options)
 158  	if err != nil {
 159  		return nil, err
 160  	}
 161  
 162  	return &publicZoneClient{
 163  		zone:         zone,
 164  		recordClient: recordClient,
 165  	}, nil
 166  }
 167  
 168  func (c publicZoneClient) Get(ctx context.Context, subDomain string) (armdns.RecordSetsClientGetResponse, error) {
 169  	return c.recordClient.Get(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, nil)
 170  }
 171  
 172  func (c publicZoneClient) CreateOrUpdate(ctx context.Context, subDomain string, rec armdns.RecordSet) (armdns.RecordSetsClientCreateOrUpdateResponse, error) {
 173  	return c.recordClient.CreateOrUpdate(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, rec, nil)
 174  }
 175  
 176  func (c publicZoneClient) Delete(ctx context.Context, subDomain string) (armdns.RecordSetsClientDeleteResponse, error) {
 177  	return c.recordClient.Delete(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, nil)
 178  }
 179  
 180  func publicUniqueRecords(recordSet armdns.RecordSet, value string) map[string]struct{} {
 181  	uniqRecords := map[string]struct{}{value: {}}
 182  
 183  	if recordSet.Properties != nil && recordSet.Properties.TxtRecords != nil {
 184  		for _, txtRecord := range recordSet.Properties.TxtRecords {
 185  			// Assume Value doesn't contain multiple strings
 186  			if len(txtRecord.Value) > 0 {
 187  				uniqRecords[ptr.Deref(txtRecord.Value[0])] = struct{}{}
 188  			}
 189  		}
 190  	}
 191  
 192  	return uniqRecords
 193  }
 194