public.go raw

   1  package azure
   2  
   3  import (
   4  	"context"
   5  	"errors"
   6  	"fmt"
   7  	"net/http"
   8  	"time"
   9  
  10  	"github.com/Azure/azure-sdk-for-go/profiles/latest/dns/mgmt/dns"
  11  	"github.com/Azure/go-autorest/autorest"
  12  	"github.com/Azure/go-autorest/autorest/to"
  13  	"github.com/go-acme/lego/v4/challenge/dns01"
  14  )
  15  
  16  // dnsProviderPublic implements the challenge.Provider interface for Azure Public Zone DNS.
  17  type dnsProviderPublic struct {
  18  	config     *Config
  19  	authorizer autorest.Authorizer
  20  }
  21  
  22  // Timeout returns the timeout and interval to use when checking for DNS propagation.
  23  // Adjusting here to cope with spikes in propagation times.
  24  func (d *dnsProviderPublic) Timeout() (timeout, interval time.Duration) {
  25  	return d.config.PropagationTimeout, d.config.PollingInterval
  26  }
  27  
  28  // Present creates a TXT record to fulfill the dns-01 challenge.
  29  func (d *dnsProviderPublic) Present(domain, token, keyAuth string) error {
  30  	ctx := context.Background()
  31  	info := dns01.GetChallengeInfo(domain, keyAuth)
  32  
  33  	zone, err := d.getHostedZoneID(ctx, info.EffectiveFQDN)
  34  	if err != nil {
  35  		return fmt.Errorf("azure: %w", err)
  36  	}
  37  
  38  	rsc := dns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)
  39  	rsc.Authorizer = d.authorizer
  40  
  41  	subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone)
  42  	if err != nil {
  43  		return fmt.Errorf("azure: %w", err)
  44  	}
  45  
  46  	// Get existing record set
  47  	rset, err := rsc.Get(ctx, d.config.ResourceGroup, zone, subDomain, dns.TXT)
  48  	if err != nil {
  49  		var detailed autorest.DetailedError
  50  		if !errors.As(err, &detailed) || detailed.StatusCode != http.StatusNotFound {
  51  			return fmt.Errorf("azure: %w", err)
  52  		}
  53  	}
  54  
  55  	// Construct unique TXT records using map
  56  	uniqRecords := map[string]struct{}{info.Value: {}}
  57  
  58  	if rset.RecordSetProperties != nil && rset.TxtRecords != nil {
  59  		for _, txtRecord := range *rset.TxtRecords {
  60  			// Assume Value doesn't contain multiple strings
  61  			values := to.StringSlice(txtRecord.Value)
  62  			if len(values) > 0 {
  63  				uniqRecords[values[0]] = struct{}{}
  64  			}
  65  		}
  66  	}
  67  
  68  	var txtRecords []dns.TxtRecord
  69  	for txt := range uniqRecords {
  70  		txtRecords = append(txtRecords, dns.TxtRecord{Value: &[]string{txt}})
  71  	}
  72  
  73  	rec := dns.RecordSet{
  74  		Name: &subDomain,
  75  		RecordSetProperties: &dns.RecordSetProperties{
  76  			TTL:        to.Int64Ptr(int64(d.config.TTL)),
  77  			TxtRecords: &txtRecords,
  78  		},
  79  	}
  80  
  81  	_, err = rsc.CreateOrUpdate(ctx, d.config.ResourceGroup, zone, subDomain, dns.TXT, rec, "", "")
  82  	if err != nil {
  83  		return fmt.Errorf("azure: %w", err)
  84  	}
  85  
  86  	return nil
  87  }
  88  
  89  // CleanUp removes the TXT record matching the specified parameters.
  90  func (d *dnsProviderPublic) CleanUp(domain, token, keyAuth string) error {
  91  	ctx := context.Background()
  92  	info := dns01.GetChallengeInfo(domain, keyAuth)
  93  
  94  	zone, err := d.getHostedZoneID(ctx, info.EffectiveFQDN)
  95  	if err != nil {
  96  		return fmt.Errorf("azure: %w", err)
  97  	}
  98  
  99  	subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone)
 100  	if err != nil {
 101  		return fmt.Errorf("azure: %w", err)
 102  	}
 103  
 104  	rsc := dns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)
 105  	rsc.Authorizer = d.authorizer
 106  
 107  	_, err = rsc.Delete(ctx, d.config.ResourceGroup, zone, subDomain, dns.TXT, "")
 108  	if err != nil {
 109  		return fmt.Errorf("azure: %w", err)
 110  	}
 111  
 112  	return nil
 113  }
 114  
 115  // Checks that azure has a zone for this domain name.
 116  func (d *dnsProviderPublic) getHostedZoneID(ctx context.Context, fqdn string) (string, error) {
 117  	if d.config.ZoneName != "" {
 118  		return d.config.ZoneName, nil
 119  	}
 120  
 121  	authZone, err := dns01.FindZoneByFqdn(fqdn)
 122  	if err != nil {
 123  		return "", fmt.Errorf("could not find zone: %w", err)
 124  	}
 125  
 126  	dc := dns.NewZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)
 127  	dc.Authorizer = d.authorizer
 128  
 129  	zone, err := dc.Get(ctx, d.config.ResourceGroup, dns01.UnFqdn(authZone))
 130  	if err != nil {
 131  		return "", err
 132  	}
 133  
 134  	// zone.Name shouldn't have a trailing dot(.)
 135  	return to.String(zone.Name), nil
 136  }
 137