precheck.go raw

   1  package dns01
   2  
   3  import (
   4  	"fmt"
   5  	"net"
   6  	"strings"
   7  	"time"
   8  
   9  	"github.com/miekg/dns"
  10  )
  11  
  12  // defaultNameserverPort used by authoritative NS.
  13  // This is for tests only.
  14  var defaultNameserverPort = "53"
  15  
// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready.
// fqdn is the name of the TXT record to query and value is the expected record content.
// It reports whether the record is considered propagated; a non-nil error describes why not.
type PreCheckFunc func(fqdn, value string) (bool, error)
  18  
// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after
// the main check, put it in a loop, etc.
// domain, fqdn and value are passed through unchanged from the pre-check caller;
// check is the default propagation check, which the wrapper may call, skip, or repeat.
type WrapPreCheckFunc func(domain, fqdn, value string, check PreCheckFunc) (bool, error)
  22  
  23  // WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
  24  func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
  25  	return func(chlg *Challenge) error {
  26  		chlg.preCheck.checkFunc = wrap
  27  		return nil
  28  	}
  29  }
  30  
  31  // DisableCompletePropagationRequirement obsolete.
  32  //
  33  // Deprecated: use DisableAuthoritativeNssPropagationRequirement instead.
  34  func DisableCompletePropagationRequirement() ChallengeOption {
  35  	return DisableAuthoritativeNssPropagationRequirement()
  36  }
  37  
  38  func DisableAuthoritativeNssPropagationRequirement() ChallengeOption {
  39  	return func(chlg *Challenge) error {
  40  		chlg.preCheck.requireAuthoritativeNssPropagation = false
  41  		return nil
  42  	}
  43  }
  44  
  45  func RecursiveNSsPropagationRequirement() ChallengeOption {
  46  	return func(chlg *Challenge) error {
  47  		chlg.preCheck.requireRecursiveNssPropagation = true
  48  		return nil
  49  	}
  50  }
  51  
  52  func PropagationWait(wait time.Duration, skipCheck bool) ChallengeOption {
  53  	return WrapPreCheck(func(domain, fqdn, value string, check PreCheckFunc) (bool, error) {
  54  		time.Sleep(wait)
  55  
  56  		if skipCheck {
  57  			return true, nil
  58  		}
  59  
  60  		return check(fqdn, value)
  61  	})
  62  }
  63  
// preCheck holds the configuration of the DNS propagation check performed
// before notifying ACME that the DNS challenge is ready.
// newPreCheck returns the default configuration.
type preCheck struct {
	// checks DNS propagation before notifying ACME that the DNS challenge is ready.
	// When nil, call runs checkDNSPropagation directly; otherwise this wrapper
	// runs and receives checkDNSPropagation as its inner check.
	checkFunc WrapPreCheckFunc

	// require the TXT record to be propagated to all authoritative name servers
	// (set to true by newPreCheck).
	requireAuthoritativeNssPropagation bool

	// require the TXT record to be propagated to all recursive name servers
	// (left false by newPreCheck).
	requireRecursiveNssPropagation bool
}
  74  
  75  func newPreCheck() preCheck {
  76  	return preCheck{
  77  		requireAuthoritativeNssPropagation: true,
  78  	}
  79  }
  80  
  81  func (p preCheck) call(domain, fqdn, value string) (bool, error) {
  82  	if p.checkFunc == nil {
  83  		return p.checkDNSPropagation(fqdn, value)
  84  	}
  85  
  86  	return p.checkFunc(domain, fqdn, value, p.checkDNSPropagation)
  87  }
  88  
  89  // checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
  90  func (p preCheck) checkDNSPropagation(fqdn, value string) (bool, error) {
  91  	// Initial attempt to resolve at the recursive NS (require to get CNAME)
  92  	r, err := dnsQuery(fqdn, dns.TypeTXT, recursiveNameservers, true)
  93  	if err != nil {
  94  		return false, fmt.Errorf("initial recursive nameserver: %w", err)
  95  	}
  96  
  97  	if r.Rcode == dns.RcodeSuccess {
  98  		fqdn = updateDomainWithCName(r, fqdn)
  99  	}
 100  
 101  	if p.requireRecursiveNssPropagation {
 102  		_, err = checkNameserversPropagation(fqdn, value, recursiveNameservers, false)
 103  		if err != nil {
 104  			return false, fmt.Errorf("recursive nameservers: %w", err)
 105  		}
 106  	}
 107  
 108  	if !p.requireAuthoritativeNssPropagation {
 109  		return true, nil
 110  	}
 111  
 112  	authoritativeNss, err := lookupNameservers(fqdn)
 113  	if err != nil {
 114  		return false, err
 115  	}
 116  
 117  	found, err := checkNameserversPropagation(fqdn, value, authoritativeNss, true)
 118  	if err != nil {
 119  		return found, fmt.Errorf("authoritative nameservers: %w", err)
 120  	}
 121  
 122  	return found, nil
 123  }
 124  
 125  // checkNameserversPropagation queries each of the given nameservers for the expected TXT record.
 126  func checkNameserversPropagation(fqdn, value string, nameservers []string, addPort bool) (bool, error) {
 127  	for _, ns := range nameservers {
 128  		if addPort {
 129  			ns = net.JoinHostPort(ns, defaultNameserverPort)
 130  		}
 131  
 132  		r, err := dnsQuery(fqdn, dns.TypeTXT, []string{ns}, false)
 133  		if err != nil {
 134  			return false, err
 135  		}
 136  
 137  		if r.Rcode != dns.RcodeSuccess {
 138  			return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
 139  		}
 140  
 141  		var records []string
 142  
 143  		var found bool
 144  
 145  		for _, rr := range r.Answer {
 146  			if txt, ok := rr.(*dns.TXT); ok {
 147  				record := strings.Join(txt.Txt, "")
 148  
 149  				records = append(records, record)
 150  				if record == value {
 151  					found = true
 152  					break
 153  				}
 154  			}
 155  		}
 156  
 157  		if !found {
 158  			return false, fmt.Errorf("NS %s did not return the expected TXT record [fqdn: %s, value: %s]: %s", ns, fqdn, value, strings.Join(records, " ,"))
 159  		}
 160  	}
 161  
 162  	return true, nil
 163  }
 164