precheck.go raw
1 package dns01
2
3 import (
4 "fmt"
5 "net"
6 "strings"
7 "time"
8
9 "github.com/miekg/dns"
10 )
11
var (
	// defaultNameserverPort is the port joined onto authoritative nameserver
	// hosts before they are queried. It is a variable (not a constant) only so
	// that tests can point the queries elsewhere.
	defaultNameserverPort = "53"
)
15
// PreCheckFunc is the signature of a DNS propagation check run before ACME is
// notified that the DNS challenge is ready. It receives the record's fqdn and
// the expected TXT value and reports whether the record is considered
// propagated.
type PreCheckFunc func(fqdn string, value string) (bool, error)
18
19 // WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after
20 // the main check, put it in a loop, etc.
21 type WrapPreCheckFunc func(domain, fqdn, value string, check PreCheckFunc) (bool, error)
22
23 // WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
24 func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
25 return func(chlg *Challenge) error {
26 chlg.preCheck.checkFunc = wrap
27 return nil
28 }
29 }
30
31 // DisableCompletePropagationRequirement obsolete.
32 //
33 // Deprecated: use DisableAuthoritativeNssPropagationRequirement instead.
34 func DisableCompletePropagationRequirement() ChallengeOption {
35 return DisableAuthoritativeNssPropagationRequirement()
36 }
37
38 func DisableAuthoritativeNssPropagationRequirement() ChallengeOption {
39 return func(chlg *Challenge) error {
40 chlg.preCheck.requireAuthoritativeNssPropagation = false
41 return nil
42 }
43 }
44
45 func RecursiveNSsPropagationRequirement() ChallengeOption {
46 return func(chlg *Challenge) error {
47 chlg.preCheck.requireRecursiveNssPropagation = true
48 return nil
49 }
50 }
51
52 func PropagationWait(wait time.Duration, skipCheck bool) ChallengeOption {
53 return WrapPreCheck(func(domain, fqdn, value string, check PreCheckFunc) (bool, error) {
54 time.Sleep(wait)
55
56 if skipCheck {
57 return true, nil
58 }
59
60 return check(fqdn, value)
61 })
62 }
63
64 type preCheck struct {
65 // checks DNS propagation before notifying ACME that the DNS challenge is ready.
66 checkFunc WrapPreCheckFunc
67
68 // require the TXT record to be propagated to all authoritative name servers
69 requireAuthoritativeNssPropagation bool
70
71 // require the TXT record to be propagated to all recursive name servers
72 requireRecursiveNssPropagation bool
73 }
74
75 func newPreCheck() preCheck {
76 return preCheck{
77 requireAuthoritativeNssPropagation: true,
78 }
79 }
80
81 func (p preCheck) call(domain, fqdn, value string) (bool, error) {
82 if p.checkFunc == nil {
83 return p.checkDNSPropagation(fqdn, value)
84 }
85
86 return p.checkFunc(domain, fqdn, value, p.checkDNSPropagation)
87 }
88
89 // checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
90 func (p preCheck) checkDNSPropagation(fqdn, value string) (bool, error) {
91 // Initial attempt to resolve at the recursive NS (require to get CNAME)
92 r, err := dnsQuery(fqdn, dns.TypeTXT, recursiveNameservers, true)
93 if err != nil {
94 return false, fmt.Errorf("initial recursive nameserver: %w", err)
95 }
96
97 if r.Rcode == dns.RcodeSuccess {
98 fqdn = updateDomainWithCName(r, fqdn)
99 }
100
101 if p.requireRecursiveNssPropagation {
102 _, err = checkNameserversPropagation(fqdn, value, recursiveNameservers, false)
103 if err != nil {
104 return false, fmt.Errorf("recursive nameservers: %w", err)
105 }
106 }
107
108 if !p.requireAuthoritativeNssPropagation {
109 return true, nil
110 }
111
112 authoritativeNss, err := lookupNameservers(fqdn)
113 if err != nil {
114 return false, err
115 }
116
117 found, err := checkNameserversPropagation(fqdn, value, authoritativeNss, true)
118 if err != nil {
119 return found, fmt.Errorf("authoritative nameservers: %w", err)
120 }
121
122 return found, nil
123 }
124
125 // checkNameserversPropagation queries each of the given nameservers for the expected TXT record.
126 func checkNameserversPropagation(fqdn, value string, nameservers []string, addPort bool) (bool, error) {
127 for _, ns := range nameservers {
128 if addPort {
129 ns = net.JoinHostPort(ns, defaultNameserverPort)
130 }
131
132 r, err := dnsQuery(fqdn, dns.TypeTXT, []string{ns}, false)
133 if err != nil {
134 return false, err
135 }
136
137 if r.Rcode != dns.RcodeSuccess {
138 return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
139 }
140
141 var records []string
142
143 var found bool
144
145 for _, rr := range r.Answer {
146 if txt, ok := rr.(*dns.TXT); ok {
147 record := strings.Join(txt.Txt, "")
148
149 records = append(records, record)
150 if record == value {
151 found = true
152 break
153 }
154 }
155 }
156
157 if !found {
158 return false, fmt.Errorf("NS %s did not return the expected TXT record [fqdn: %s, value: %s]: %s", ns, fqdn, value, strings.Join(records, " ,"))
159 }
160 }
161
162 return true, nil
163 }
164