dns.mx raw

   1  package ws
   2  
   3  import (
   4  	"fmt"
   5  	"os"
   6  	"sync"
   7  	"syscall"
   8  	"time"
   9  )
  10  
  11  const dnsTTL = 24 * time.Hour
  12  
  13  // dnsAddr is read from /etc/resolv.conf at init time.
  14  var dnsAddr [4]byte
  15  
  16  func init() {
  17  	dnsAddr = [4]byte{8, 8, 8, 8} // fallback
  18  	data, err := os.ReadFile("/etc/resolv.conf")
  19  	if err != nil {
  20  		return
  21  	}
  22  	i := 0
  23  	for i < len(data) {
  24  		lineStart := i
  25  		for i < len(data) && data[i] != '\n' {
  26  			i++
  27  		}
  28  		line := data[lineStart:i]
  29  		if i < len(data) {
  30  			i++
  31  		}
  32  		if len(line) > 11 && string(line[:11]) == "nameserver " {
  33  			ns := line[11:]
  34  			for len(ns) > 0 && (ns[len(ns)-1] == ' ' || ns[len(ns)-1] == '\t' || ns[len(ns)-1] == '\r') {
  35  				ns = ns[:len(ns)-1]
  36  			}
  37  			if ip := parseIPv4(ns); ip != nil {
  38  				dnsAddr = [4]byte{ip[0], ip[1], ip[2], ip[3]}
  39  				return
  40  			}
  41  		}
  42  	}
  43  }
  44  
  45  func parseIPv4(s []byte) []byte {
  46  	var parts [4]byte
  47  	p := 0
  48  	v := 0
  49  	for i := 0; i <= len(s); i++ {
  50  		if i == len(s) || s[i] == '.' {
  51  			if v > 255 || p > 3 {
  52  				return nil
  53  			}
  54  			parts[p] = byte(v)
  55  			p++
  56  			v = 0
  57  		} else if s[i] >= '0' && s[i] <= '9' {
  58  			v = v*10 + int(s[i]-'0')
  59  		} else {
  60  			return nil
  61  		}
  62  	}
  63  	if p != 4 {
  64  		return nil
  65  	}
  66  	return parts[:]
  67  }
  68  
  69  type dnsEntry struct {
  70  	ip  string
  71  	exp time.Time
  72  }
  73  
  74  var (
  75  	dnsMu    sync.Mutex
  76  	dnsCache = map[string]dnsEntry{}
  77  )
  78  
  79  // resolveHost returns a cached IP for the hostname, or resolves via raw UDP DNS.
  80  func resolveHost(host string) (string, error) {
  81  	dnsMu.Lock()
  82  	if e, ok := dnsCache[host]; ok && time.Now().Before(e.exp) {
  83  		dnsMu.Unlock()
  84  		return e.ip, nil
  85  	}
  86  	dnsMu.Unlock()
  87  
  88  	ip, err := dnsLookup(host)
  89  	if err != nil {
  90  		return "", err
  91  	}
  92  
  93  	dnsMu.Lock()
  94  	dnsCache[host] = dnsEntry{ip: ip, exp: time.Now().Add(dnsTTL)}
  95  	dnsMu.Unlock()
  96  
  97  	return ip, nil
  98  }
  99  
 100  // dnsLookup sends a minimal DNS A-record query via raw syscalls.
 101  func dnsLookup(host string) (string, error) {
 102  	fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0)
 103  	if err != nil {
 104  		return "", fmt.Errorf("dns: socket: %w", err)
 105  	}
 106  	defer syscall.Close(fd)
 107  
 108  	tv := syscall.Timeval{Sec: 5}
 109  	syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv)
 110  
 111  	sa := &syscall.SockaddrInet4{Port: 53, Addr: dnsAddr}
 112  	if err := syscall.Connect(fd, sa); err != nil {
 113  		return "", fmt.Errorf("dns: connect: %w", err)
 114  	}
 115  
 116  	// Build DNS A-record query.
 117  	var pkt []byte
 118  	pkt = append(pkt, 0xAB, 0xCD)                         // ID
 119  	pkt = append(pkt, 0x01, 0x00)                         // flags: recursion desired
 120  	pkt = append(pkt, 0x00, 0x01)                         // 1 question
 121  	pkt = append(pkt, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) // 0 answer/auth/additional
 122  
 123  	// Encode QNAME.
 124  	j := 0
 125  	for j < len(host) {
 126  		dot := j
 127  		for dot < len(host) && host[dot] != '.' {
 128  			dot++
 129  		}
 130  		pkt = append(pkt, byte(dot-j))
 131  		pkt = append(pkt, host[j:dot]...)
 132  		j = dot + 1
 133  	}
 134  	pkt = append(pkt, 0x00)       // root
 135  	pkt = append(pkt, 0x00, 0x01) // QTYPE A
 136  	pkt = append(pkt, 0x00, 0x01) // QCLASS IN
 137  
 138  	if err := syscall.Sendto(fd, pkt, 0, sa); err != nil {
 139  		return "", fmt.Errorf("dns: send: %w", err)
 140  	}
 141  
 142  	buf := []byte{:512}
 143  	n, _, err := syscall.Recvfrom(fd, buf, 0)
 144  	if err != nil {
 145  		return "", fmt.Errorf("dns: recv: %w", err)
 146  	}
 147  	if n < 12 {
 148  		return "", fmt.Errorf("dns: response too short")
 149  	}
 150  
 151  	anCount := int(buf[6])<<8 | int(buf[7])
 152  	if anCount == 0 {
 153  		return "", fmt.Errorf("dns: no answers for %s", host)
 154  	}
 155  
 156  	// Skip question section.
 157  	pos := 12
 158  	for pos < n {
 159  		if buf[pos] == 0 {
 160  			pos++
 161  			break
 162  		}
 163  		if buf[pos]&0xC0 == 0xC0 {
 164  			pos += 2
 165  			break
 166  		}
 167  		pos += int(buf[pos]) + 1
 168  	}
 169  	pos += 4 // QTYPE + QCLASS
 170  
 171  	// Find first A record.
 172  	for a := 0; a < anCount && pos+10 < n; a++ {
 173  		if buf[pos]&0xC0 == 0xC0 {
 174  			pos += 2
 175  		} else {
 176  			for pos < n && buf[pos] != 0 {
 177  				pos += int(buf[pos]) + 1
 178  			}
 179  			pos++
 180  		}
 181  		if pos+10 > n {
 182  			break
 183  		}
 184  		rtype := int(buf[pos])<<8 | int(buf[pos+1])
 185  		rdlen := int(buf[pos+8])<<8 | int(buf[pos+9])
 186  		pos += 10
 187  		if rtype == 1 && rdlen == 4 && pos+4 <= n {
 188  			// Copy Sprintf result to avoid buffer reuse corruption.
 189  			s := fmt.Sprintf("%d.%d.%d.%d", buf[pos], buf[pos+1], buf[pos+2], buf[pos+3])
 190  			ip := []byte{:len(s)}
 191  			copy(ip, s)
 192  			return string(ip), nil
 193  		}
 194  		pos += rdlen
 195  	}
 196  
 197  	return "", fmt.Errorf("dns: no A record for %s", host)
 198  }
 199