nameserver.go raw
1 package dns01
2
3 import (
4 "errors"
5 "fmt"
6 "net"
7 "os"
8 "slices"
9 "strconv"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/miekg/dns"
15 )
16
17 const defaultResolvConf = "/etc/resolv.conf"
18
19 var fqdnSoaCache = &sync.Map{}
20
21 var defaultNameservers = []string{
22 "google-public-dns-a.google.com:53",
23 "google-public-dns-b.google.com:53",
24 }
25
26 // recursiveNameservers are used to pre-check DNS propagation.
27 var recursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
28
// soaCacheEntry is the subset of an SOA record that is kept in fqdnSoaCache.
type soaCacheEntry struct {
	zone      string    // zone apex (owner name of the SOA record)
	primaryNs string    // primary nameserver of the zone apex
	expires   time.Time // instant after which the entry is treated as stale
}
35
36 func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry {
37 return &soaCacheEntry{
38 zone: soa.Hdr.Name,
39 primaryNs: soa.Ns,
40 expires: time.Now().Add(time.Duration(soa.Refresh) * time.Second),
41 }
42 }
43
44 // isExpired checks whether a cache entry should be considered expired.
45 func (cache *soaCacheEntry) isExpired() bool {
46 return time.Now().After(cache.expires)
47 }
48
49 // ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
50 func ClearFqdnCache() {
51 // TODO(ldez): use `fqdnSoaCache.Clear()` when updating to go1.23
52 fqdnSoaCache.Range(func(k, v any) bool {
53 fqdnSoaCache.Delete(k)
54 return true
55 })
56 }
57
58 func AddDNSTimeout(timeout time.Duration) ChallengeOption {
59 return func(_ *Challenge) error {
60 dnsTimeout = timeout
61 return nil
62 }
63 }
64
65 func AddRecursiveNameservers(nameservers []string) ChallengeOption {
66 return func(_ *Challenge) error {
67 recursiveNameservers = ParseNameservers(nameservers)
68 return nil
69 }
70 }
71
72 // getNameservers attempts to get systems nameservers before falling back to the defaults.
73 func getNameservers(path string, defaults []string) []string {
74 config, err := dns.ClientConfigFromFile(path)
75 if err != nil || len(config.Servers) == 0 {
76 return defaults
77 }
78
79 return ParseNameservers(config.Servers)
80 }
81
// ParseNameservers normalizes nameserver addresses to "host:port" form.
//
// Entries that already include a port are returned unchanged. Entries without
// a port get the standard DNS port 53; IPv6 literals are bracketed by
// net.JoinHostPort. An IPv6 literal that is already bracketed but has no port
// (e.g. "[2001:db8::1]") is unwrapped first so it is not double-bracketed.
// The result preserves the input order; a nil/empty input yields nil.
func ParseNameservers(servers []string) []string {
	var resolvers []string

	for _, resolver := range servers {
		// Keep entries that already carry a port number.
		if _, _, err := net.SplitHostPort(resolver); err == nil {
			resolvers = append(resolvers, resolver)
			continue
		}

		host := resolver
		// JoinHostPort adds its own brackets around hosts containing ':',
		// so strip any the caller already supplied.
		if len(host) >= 2 && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}

		resolvers = append(resolvers, net.JoinHostPort(host, "53"))
	}

	return resolvers
}
96
97 // lookupNameservers returns the authoritative nameservers for the given fqdn.
98 func lookupNameservers(fqdn string) ([]string, error) {
99 var authoritativeNss []string
100
101 zone, err := FindZoneByFqdn(fqdn)
102 if err != nil {
103 return nil, fmt.Errorf("could not find zone: %w", err)
104 }
105
106 r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true)
107 if err != nil {
108 return nil, fmt.Errorf("NS call failed: %w", err)
109 }
110
111 for _, rr := range r.Answer {
112 if ns, ok := rr.(*dns.NS); ok {
113 authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
114 }
115 }
116
117 if len(authoritativeNss) > 0 {
118 return authoritativeNss, nil
119 }
120
121 return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
122 }
123
124 // FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn
125 // by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
126 func FindPrimaryNsByFqdn(fqdn string) (string, error) {
127 return FindPrimaryNsByFqdnCustom(fqdn, recursiveNameservers)
128 }
129
130 // FindPrimaryNsByFqdnCustom determines the primary nameserver of the zone apex for the given fqdn
131 // by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
132 func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) {
133 soa, err := lookupSoaByFqdn(fqdn, nameservers)
134 if err != nil {
135 return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
136 }
137
138 return soa.primaryNs, nil
139 }
140
141 // FindZoneByFqdn determines the zone apex for the given fqdn
142 // by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
143 func FindZoneByFqdn(fqdn string) (string, error) {
144 return FindZoneByFqdnCustom(fqdn, recursiveNameservers)
145 }
146
147 // FindZoneByFqdnCustom determines the zone apex for the given fqdn
148 // by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
149 func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) {
150 soa, err := lookupSoaByFqdn(fqdn, nameservers)
151 if err != nil {
152 return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
153 }
154
155 return soa.zone, nil
156 }
157
158 func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
159 // Do we have it cached and is it still fresh?
160 entAny, ok := fqdnSoaCache.Load(fqdn)
161 if ok && entAny != nil {
162 ent, ok1 := entAny.(*soaCacheEntry)
163 if ok1 && !ent.isExpired() {
164 return ent, nil
165 }
166 }
167
168 ent, err := fetchSoaByFqdn(fqdn, nameservers)
169 if err != nil {
170 return nil, err
171 }
172
173 fqdnSoaCache.Store(fqdn, ent)
174
175 return ent, nil
176 }
177
178 func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
179 var (
180 err error
181 r *dns.Msg
182 )
183
184 for domain := range DomainsSeq(fqdn) {
185 r, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
186 if err != nil {
187 continue
188 }
189
190 if r == nil {
191 continue
192 }
193
194 switch r.Rcode {
195 case dns.RcodeSuccess:
196 // Check if we got a SOA RR in the answer section
197 if len(r.Answer) == 0 {
198 continue
199 }
200
201 // CNAME records cannot/should not exist at the root of a zone.
202 // So we skip a domain when a CNAME is found.
203 if dnsMsgContainsCNAME(r) {
204 continue
205 }
206
207 for _, ans := range r.Answer {
208 if soa, ok := ans.(*dns.SOA); ok {
209 return newSoaCacheEntry(soa), nil
210 }
211 }
212 case dns.RcodeNameError:
213 // NXDOMAIN
214 default:
215 // Any response code other than NOERROR and NXDOMAIN is treated as error
216 return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
217 }
218 }
219
220 return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
221 }
222
223 // dnsMsgContainsCNAME checks for a CNAME answer in msg.
224 func dnsMsgContainsCNAME(msg *dns.Msg) bool {
225 return slices.ContainsFunc(msg.Answer, func(rr dns.RR) bool {
226 _, ok := rr.(*dns.CNAME)
227 return ok
228 })
229 }
230
231 func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
232 m := createDNSMsg(fqdn, rtype, recursive)
233
234 if len(nameservers) == 0 {
235 return nil, &DNSError{Message: "empty list of nameservers"}
236 }
237
238 var (
239 r *dns.Msg
240 err error
241 errAll error
242 )
243
244 for _, ns := range nameservers {
245 r, err = sendDNSQuery(m, ns)
246 if err == nil && len(r.Answer) > 0 {
247 break
248 }
249
250 errAll = errors.Join(errAll, err)
251 }
252
253 if err != nil {
254 return r, errAll
255 }
256
257 return r, nil
258 }
259
260 func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
261 m := new(dns.Msg)
262 m.SetQuestion(fqdn, rtype)
263 m.SetEdns0(4096, false)
264
265 if !recursive {
266 m.RecursionDesired = false
267 }
268
269 return m
270 }
271
272 func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
273 if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok {
274 tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
275
276 r, _, err := tcp.Exchange(m, ns)
277 if err != nil {
278 return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
279 }
280
281 return r, nil
282 }
283
284 udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
285 r, _, err := udp.Exchange(m, ns)
286
287 if r != nil && r.Truncated {
288 tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
289 // If the TCP request succeeds, the "err" will reset to nil
290 r, _, err = tcp.Exchange(m, ns)
291 }
292
293 if err != nil {
294 return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
295 }
296
297 return r, nil
298 }
299
300 // DNSError error related to DNS calls.
301 type DNSError struct {
302 Message string
303 NS string
304 MsgIn *dns.Msg
305 MsgOut *dns.Msg
306 Err error
307 }
308
309 func (d *DNSError) Error() string {
310 var details []string
311 if d.NS != "" {
312 details = append(details, "ns="+d.NS)
313 }
314
315 if d.MsgIn != nil && len(d.MsgIn.Question) > 0 {
316 details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question)))
317 }
318
319 if d.MsgOut != nil {
320 if d.MsgIn == nil || len(d.MsgIn.Question) == 0 {
321 details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question)))
322 }
323
324 details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode])
325 }
326
327 msg := "DNS error"
328 if d.Message != "" {
329 msg = d.Message
330 }
331
332 if d.Err != nil {
333 msg += ": " + d.Err.Error()
334 }
335
336 if len(details) > 0 {
337 msg += " [" + strings.Join(details, ", ") + "]"
338 }
339
340 return msg
341 }
342
343 func (d *DNSError) Unwrap() error {
344 return d.Err
345 }
346
347 func formatQuestions(questions []dns.Question) string {
348 var parts []string
349 for _, question := range questions {
350 parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " "))
351 }
352
353 return strings.Join(parts, ";")
354 }
355