dns.mx raw
1 package ws
2
3 import (
4 "fmt"
5 "os"
6 "sync"
7 "syscall"
8 "time"
9 )
10
// dnsTTL bounds how long resolveHost reuses a cached answer before querying
// again; the TTL carried in the DNS response itself is not consulted.
const dnsTTL = time.Hour * 24

// dnsAddr is the IPv4 address of the DNS server that dnsLookup queries on
// port 53. init fills it from /etc/resolv.conf, falling back to 8.8.8.8.
var dnsAddr [4]byte
15
16 func init() {
17 dnsAddr = [4]byte{8, 8, 8, 8} // fallback
18 data, err := os.ReadFile("/etc/resolv.conf")
19 if err != nil {
20 return
21 }
22 i := 0
23 for i < len(data) {
24 lineStart := i
25 for i < len(data) && data[i] != '\n' {
26 i++
27 }
28 line := data[lineStart:i]
29 if i < len(data) {
30 i++
31 }
32 if len(line) > 11 && string(line[:11]) == "nameserver " {
33 ns := line[11:]
34 for len(ns) > 0 && (ns[len(ns)-1] == ' ' || ns[len(ns)-1] == '\t' || ns[len(ns)-1] == '\r') {
35 ns = ns[:len(ns)-1]
36 }
37 if ip := parseIPv4(ns); ip != nil {
38 dnsAddr = [4]byte{ip[0], ip[1], ip[2], ip[3]}
39 return
40 }
41 }
42 }
43 }
44
// parseIPv4 parses a dotted-decimal IPv4 address such as "192.0.2.1" and
// returns its four octets. It returns nil unless s is exactly four
// dot-separated, non-empty decimal fields, each in the range 0-255.
// Surrounding whitespace is not allowed.
func parseIPv4(s []byte) []byte {
	var parts [4]byte
	p := 0      // index of the octet currently being parsed
	v := 0      // value accumulated for the current octet
	digits := 0 // number of digits seen in the current octet
	for i := 0; i <= len(s); i++ {
		if i == len(s) || s[i] == '.' {
			// Reject empty fields ("1..2.3", ".1.2.3", "1.2.3.") and a fifth field.
			if digits == 0 || p > 3 {
				return nil
			}
			parts[p] = byte(v)
			p++
			v, digits = 0, 0
			continue
		}
		if s[i] < '0' || s[i] > '9' {
			return nil
		}
		v = v*10 + int(s[i]-'0')
		digits++
		// Checking after every digit keeps v small, so a long run of digits
		// can neither overflow int nor wrap back into the valid range.
		if v > 255 {
			return nil
		}
	}
	if p != 4 {
		return nil
	}
	return parts[:]
}
68
// dnsEntry is one cached hostname resolution held in dnsCache.
type dnsEntry struct {
	ip  string    // resolved IPv4 address in dotted-decimal form
	exp time.Time // the entry is used only while time.Now() is before exp
}

// dnsMu guards every read and write of dnsCache.
var dnsMu sync.Mutex

// dnsCache maps a hostname to its most recent successful lookup.
var dnsCache = make(map[string]dnsEntry)
78
79 // resolveHost returns a cached IP for the hostname, or resolves via raw UDP DNS.
80 func resolveHost(host string) (string, error) {
81 dnsMu.Lock()
82 if e, ok := dnsCache[host]; ok && time.Now().Before(e.exp) {
83 dnsMu.Unlock()
84 return e.ip, nil
85 }
86 dnsMu.Unlock()
87
88 ip, err := dnsLookup(host)
89 if err != nil {
90 return "", err
91 }
92
93 dnsMu.Lock()
94 dnsCache[host] = dnsEntry{ip: ip, exp: time.Now().Add(dnsTTL)}
95 dnsMu.Unlock()
96
97 return ip, nil
98 }
99
100 // dnsLookup sends a minimal DNS A-record query via raw syscalls.
101 func dnsLookup(host string) (string, error) {
102 fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0)
103 if err != nil {
104 return "", fmt.Errorf("dns: socket: %w", err)
105 }
106 defer syscall.Close(fd)
107
108 tv := syscall.Timeval{Sec: 5}
109 syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv)
110
111 sa := &syscall.SockaddrInet4{Port: 53, Addr: dnsAddr}
112 if err := syscall.Connect(fd, sa); err != nil {
113 return "", fmt.Errorf("dns: connect: %w", err)
114 }
115
116 // Build DNS A-record query.
117 var pkt []byte
118 pkt = append(pkt, 0xAB, 0xCD) // ID
119 pkt = append(pkt, 0x01, 0x00) // flags: recursion desired
120 pkt = append(pkt, 0x00, 0x01) // 1 question
121 pkt = append(pkt, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) // 0 answer/auth/additional
122
123 // Encode QNAME.
124 j := 0
125 for j < len(host) {
126 dot := j
127 for dot < len(host) && host[dot] != '.' {
128 dot++
129 }
130 pkt = append(pkt, byte(dot-j))
131 pkt = append(pkt, host[j:dot]...)
132 j = dot + 1
133 }
134 pkt = append(pkt, 0x00) // root
135 pkt = append(pkt, 0x00, 0x01) // QTYPE A
136 pkt = append(pkt, 0x00, 0x01) // QCLASS IN
137
138 if err := syscall.Sendto(fd, pkt, 0, sa); err != nil {
139 return "", fmt.Errorf("dns: send: %w", err)
140 }
141
142 buf := []byte{:512}
143 n, _, err := syscall.Recvfrom(fd, buf, 0)
144 if err != nil {
145 return "", fmt.Errorf("dns: recv: %w", err)
146 }
147 if n < 12 {
148 return "", fmt.Errorf("dns: response too short")
149 }
150
151 anCount := int(buf[6])<<8 | int(buf[7])
152 if anCount == 0 {
153 return "", fmt.Errorf("dns: no answers for %s", host)
154 }
155
156 // Skip question section.
157 pos := 12
158 for pos < n {
159 if buf[pos] == 0 {
160 pos++
161 break
162 }
163 if buf[pos]&0xC0 == 0xC0 {
164 pos += 2
165 break
166 }
167 pos += int(buf[pos]) + 1
168 }
169 pos += 4 // QTYPE + QCLASS
170
171 // Find first A record.
172 for a := 0; a < anCount && pos+10 < n; a++ {
173 if buf[pos]&0xC0 == 0xC0 {
174 pos += 2
175 } else {
176 for pos < n && buf[pos] != 0 {
177 pos += int(buf[pos]) + 1
178 }
179 pos++
180 }
181 if pos+10 > n {
182 break
183 }
184 rtype := int(buf[pos])<<8 | int(buf[pos+1])
185 rdlen := int(buf[pos+8])<<8 | int(buf[pos+9])
186 pos += 10
187 if rtype == 1 && rdlen == 4 && pos+4 <= n {
188 // Copy Sprintf result to avoid buffer reuse corruption.
189 s := fmt.Sprintf("%d.%d.%d.%d", buf[pos], buf[pos+1], buf[pos+2], buf[pos+3])
190 ip := []byte{:len(s)}
191 copy(ip, s)
192 return string(ip), nil
193 }
194 pos += rdlen
195 }
196
197 return "", fmt.Errorf("dns: no A record for %s", host)
198 }
199