nameserver.go raw

   1  package dns01
   2  
   3  import (
   4  	"errors"
   5  	"fmt"
   6  	"net"
   7  	"os"
   8  	"slices"
   9  	"strconv"
  10  	"strings"
  11  	"sync"
  12  	"time"
  13  
  14  	"github.com/miekg/dns"
  15  )
  16  
  17  const defaultResolvConf = "/etc/resolv.conf"
  18  
  19  var fqdnSoaCache = &sync.Map{}
  20  
  21  var defaultNameservers = []string{
  22  	"google-public-dns-a.google.com:53",
  23  	"google-public-dns-b.google.com:53",
  24  }
  25  
  26  // recursiveNameservers are used to pre-check DNS propagation.
  27  var recursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
  28  
  29  // soaCacheEntry holds a cached SOA record (only selected fields).
  30  type soaCacheEntry struct {
  31  	zone      string    // zone apex (a domain name)
  32  	primaryNs string    // primary nameserver for the zone apex
  33  	expires   time.Time // time when this cache entry should be evicted
  34  }
  35  
  36  func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry {
  37  	return &soaCacheEntry{
  38  		zone:      soa.Hdr.Name,
  39  		primaryNs: soa.Ns,
  40  		expires:   time.Now().Add(time.Duration(soa.Refresh) * time.Second),
  41  	}
  42  }
  43  
  44  // isExpired checks whether a cache entry should be considered expired.
  45  func (cache *soaCacheEntry) isExpired() bool {
  46  	return time.Now().After(cache.expires)
  47  }
  48  
  49  // ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
  50  func ClearFqdnCache() {
  51  	// TODO(ldez): use `fqdnSoaCache.Clear()` when updating to go1.23
  52  	fqdnSoaCache.Range(func(k, v any) bool {
  53  		fqdnSoaCache.Delete(k)
  54  		return true
  55  	})
  56  }
  57  
  58  func AddDNSTimeout(timeout time.Duration) ChallengeOption {
  59  	return func(_ *Challenge) error {
  60  		dnsTimeout = timeout
  61  		return nil
  62  	}
  63  }
  64  
  65  func AddRecursiveNameservers(nameservers []string) ChallengeOption {
  66  	return func(_ *Challenge) error {
  67  		recursiveNameservers = ParseNameservers(nameservers)
  68  		return nil
  69  	}
  70  }
  71  
  72  // getNameservers attempts to get systems nameservers before falling back to the defaults.
  73  func getNameservers(path string, defaults []string) []string {
  74  	config, err := dns.ClientConfigFromFile(path)
  75  	if err != nil || len(config.Servers) == 0 {
  76  		return defaults
  77  	}
  78  
  79  	return ParseNameservers(config.Servers)
  80  }
  81  
  82  func ParseNameservers(servers []string) []string {
  83  	var resolvers []string
  84  
  85  	for _, resolver := range servers {
  86  		// ensure all servers have a port number
  87  		if _, _, err := net.SplitHostPort(resolver); err != nil {
  88  			resolvers = append(resolvers, net.JoinHostPort(resolver, "53"))
  89  		} else {
  90  			resolvers = append(resolvers, resolver)
  91  		}
  92  	}
  93  
  94  	return resolvers
  95  }
  96  
  97  // lookupNameservers returns the authoritative nameservers for the given fqdn.
  98  func lookupNameservers(fqdn string) ([]string, error) {
  99  	var authoritativeNss []string
 100  
 101  	zone, err := FindZoneByFqdn(fqdn)
 102  	if err != nil {
 103  		return nil, fmt.Errorf("could not find zone: %w", err)
 104  	}
 105  
 106  	r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true)
 107  	if err != nil {
 108  		return nil, fmt.Errorf("NS call failed: %w", err)
 109  	}
 110  
 111  	for _, rr := range r.Answer {
 112  		if ns, ok := rr.(*dns.NS); ok {
 113  			authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
 114  		}
 115  	}
 116  
 117  	if len(authoritativeNss) > 0 {
 118  		return authoritativeNss, nil
 119  	}
 120  
 121  	return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
 122  }
 123  
 124  // FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn
 125  // by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
 126  func FindPrimaryNsByFqdn(fqdn string) (string, error) {
 127  	return FindPrimaryNsByFqdnCustom(fqdn, recursiveNameservers)
 128  }
 129  
 130  // FindPrimaryNsByFqdnCustom determines the primary nameserver of the zone apex for the given fqdn
 131  // by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
 132  func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) {
 133  	soa, err := lookupSoaByFqdn(fqdn, nameservers)
 134  	if err != nil {
 135  		return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
 136  	}
 137  
 138  	return soa.primaryNs, nil
 139  }
 140  
 141  // FindZoneByFqdn determines the zone apex for the given fqdn
 142  // by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
 143  func FindZoneByFqdn(fqdn string) (string, error) {
 144  	return FindZoneByFqdnCustom(fqdn, recursiveNameservers)
 145  }
 146  
 147  // FindZoneByFqdnCustom determines the zone apex for the given fqdn
 148  // by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
 149  func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) {
 150  	soa, err := lookupSoaByFqdn(fqdn, nameservers)
 151  	if err != nil {
 152  		return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
 153  	}
 154  
 155  	return soa.zone, nil
 156  }
 157  
 158  func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
 159  	// Do we have it cached and is it still fresh?
 160  	entAny, ok := fqdnSoaCache.Load(fqdn)
 161  	if ok && entAny != nil {
 162  		ent, ok1 := entAny.(*soaCacheEntry)
 163  		if ok1 && !ent.isExpired() {
 164  			return ent, nil
 165  		}
 166  	}
 167  
 168  	ent, err := fetchSoaByFqdn(fqdn, nameservers)
 169  	if err != nil {
 170  		return nil, err
 171  	}
 172  
 173  	fqdnSoaCache.Store(fqdn, ent)
 174  
 175  	return ent, nil
 176  }
 177  
 178  func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
 179  	var (
 180  		err error
 181  		r   *dns.Msg
 182  	)
 183  
 184  	for domain := range DomainsSeq(fqdn) {
 185  		r, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
 186  		if err != nil {
 187  			continue
 188  		}
 189  
 190  		if r == nil {
 191  			continue
 192  		}
 193  
 194  		switch r.Rcode {
 195  		case dns.RcodeSuccess:
 196  			// Check if we got a SOA RR in the answer section
 197  			if len(r.Answer) == 0 {
 198  				continue
 199  			}
 200  
 201  			// CNAME records cannot/should not exist at the root of a zone.
 202  			// So we skip a domain when a CNAME is found.
 203  			if dnsMsgContainsCNAME(r) {
 204  				continue
 205  			}
 206  
 207  			for _, ans := range r.Answer {
 208  				if soa, ok := ans.(*dns.SOA); ok {
 209  					return newSoaCacheEntry(soa), nil
 210  				}
 211  			}
 212  		case dns.RcodeNameError:
 213  			// NXDOMAIN
 214  		default:
 215  			// Any response code other than NOERROR and NXDOMAIN is treated as error
 216  			return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
 217  		}
 218  	}
 219  
 220  	return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
 221  }
 222  
 223  // dnsMsgContainsCNAME checks for a CNAME answer in msg.
 224  func dnsMsgContainsCNAME(msg *dns.Msg) bool {
 225  	return slices.ContainsFunc(msg.Answer, func(rr dns.RR) bool {
 226  		_, ok := rr.(*dns.CNAME)
 227  		return ok
 228  	})
 229  }
 230  
 231  func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
 232  	m := createDNSMsg(fqdn, rtype, recursive)
 233  
 234  	if len(nameservers) == 0 {
 235  		return nil, &DNSError{Message: "empty list of nameservers"}
 236  	}
 237  
 238  	var (
 239  		r      *dns.Msg
 240  		err    error
 241  		errAll error
 242  	)
 243  
 244  	for _, ns := range nameservers {
 245  		r, err = sendDNSQuery(m, ns)
 246  		if err == nil && len(r.Answer) > 0 {
 247  			break
 248  		}
 249  
 250  		errAll = errors.Join(errAll, err)
 251  	}
 252  
 253  	if err != nil {
 254  		return r, errAll
 255  	}
 256  
 257  	return r, nil
 258  }
 259  
 260  func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
 261  	m := new(dns.Msg)
 262  	m.SetQuestion(fqdn, rtype)
 263  	m.SetEdns0(4096, false)
 264  
 265  	if !recursive {
 266  		m.RecursionDesired = false
 267  	}
 268  
 269  	return m
 270  }
 271  
 272  func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
 273  	if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok {
 274  		tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
 275  
 276  		r, _, err := tcp.Exchange(m, ns)
 277  		if err != nil {
 278  			return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
 279  		}
 280  
 281  		return r, nil
 282  	}
 283  
 284  	udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
 285  	r, _, err := udp.Exchange(m, ns)
 286  
 287  	if r != nil && r.Truncated {
 288  		tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
 289  		// If the TCP request succeeds, the "err" will reset to nil
 290  		r, _, err = tcp.Exchange(m, ns)
 291  	}
 292  
 293  	if err != nil {
 294  		return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
 295  	}
 296  
 297  	return r, nil
 298  }
 299  
 300  // DNSError error related to DNS calls.
 301  type DNSError struct {
 302  	Message string
 303  	NS      string
 304  	MsgIn   *dns.Msg
 305  	MsgOut  *dns.Msg
 306  	Err     error
 307  }
 308  
 309  func (d *DNSError) Error() string {
 310  	var details []string
 311  	if d.NS != "" {
 312  		details = append(details, "ns="+d.NS)
 313  	}
 314  
 315  	if d.MsgIn != nil && len(d.MsgIn.Question) > 0 {
 316  		details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question)))
 317  	}
 318  
 319  	if d.MsgOut != nil {
 320  		if d.MsgIn == nil || len(d.MsgIn.Question) == 0 {
 321  			details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question)))
 322  		}
 323  
 324  		details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode])
 325  	}
 326  
 327  	msg := "DNS error"
 328  	if d.Message != "" {
 329  		msg = d.Message
 330  	}
 331  
 332  	if d.Err != nil {
 333  		msg += ": " + d.Err.Error()
 334  	}
 335  
 336  	if len(details) > 0 {
 337  		msg += " [" + strings.Join(details, ", ") + "]"
 338  	}
 339  
 340  	return msg
 341  }
 342  
 343  func (d *DNSError) Unwrap() error {
 344  	return d.Err
 345  }
 346  
 347  func formatQuestions(questions []dns.Question) string {
 348  	var parts []string
 349  	for _, question := range questions {
 350  		parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " "))
 351  	}
 352  
 353  	return strings.Join(parts, ";")
 354  }
 355