package ws import ( "fmt" "os" "sync" "syscall" "time" ) const dnsTTL = 24 * time.Hour // dnsAddr is read from /etc/resolv.conf at init time. var dnsAddr [4]byte func init() { dnsAddr = [4]byte{8, 8, 8, 8} // fallback data, err := os.ReadFile("/etc/resolv.conf") if err != nil { return } i := 0 for i < len(data) { lineStart := i for i < len(data) && data[i] != '\n' { i++ } line := data[lineStart:i] if i < len(data) { i++ } if len(line) > 11 && string(line[:11]) == "nameserver " { ns := line[11:] for len(ns) > 0 && (ns[len(ns)-1] == ' ' || ns[len(ns)-1] == '\t' || ns[len(ns)-1] == '\r') { ns = ns[:len(ns)-1] } if ip := parseIPv4(ns); ip != nil { dnsAddr = [4]byte{ip[0], ip[1], ip[2], ip[3]} return } } } } func parseIPv4(s []byte) []byte { var parts [4]byte p := 0 v := 0 for i := 0; i <= len(s); i++ { if i == len(s) || s[i] == '.' { if v > 255 || p > 3 { return nil } parts[p] = byte(v) p++ v = 0 } else if s[i] >= '0' && s[i] <= '9' { v = v*10 + int(s[i]-'0') } else { return nil } } if p != 4 { return nil } return parts[:] } type dnsEntry struct { ip string exp time.Time } var ( dnsMu sync.Mutex dnsCache = map[string]dnsEntry{} ) // resolveHost returns a cached IP for the hostname, or resolves via raw UDP DNS. func resolveHost(host string) (string, error) { dnsMu.Lock() if e, ok := dnsCache[host]; ok && time.Now().Before(e.exp) { dnsMu.Unlock() return e.ip, nil } dnsMu.Unlock() ip, err := dnsLookup(host) if err != nil { return "", err } dnsMu.Lock() dnsCache[host] = dnsEntry{ip: ip, exp: time.Now().Add(dnsTTL)} dnsMu.Unlock() return ip, nil } // dnsLookup sends a minimal DNS A-record query via raw syscalls. func dnsLookup(host string) (string, error) { fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0) if err != nil { return "", fmt.Errorf("dns: socket: %w", err) } defer syscall.Close(fd) tv := syscall.Timeval{Sec: 5} syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv) sa := &syscall.SockaddrInet4{Port: 53, Addr: dnsAddr} if err := syscall.Connect(fd, sa); err != nil { return "", fmt.Errorf("dns: connect: %w", err) } // Build DNS A-record query. var pkt []byte pkt = append(pkt, 0xAB, 0xCD) // ID pkt = append(pkt, 0x01, 0x00) // flags: recursion desired pkt = append(pkt, 0x00, 0x01) // 1 question pkt = append(pkt, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) // 0 answer/auth/additional // Encode QNAME. j := 0 for j < len(host) { dot := j for dot < len(host) && host[dot] != '.' { dot++ } pkt = append(pkt, byte(dot-j)) pkt = append(pkt, host[j:dot]...) j = dot + 1 } pkt = append(pkt, 0x00) // root pkt = append(pkt, 0x00, 0x01) // QTYPE A pkt = append(pkt, 0x00, 0x01) // QCLASS IN if err := syscall.Sendto(fd, pkt, 0, sa); err != nil { return "", fmt.Errorf("dns: send: %w", err) } buf := []byte{:512} n, _, err := syscall.Recvfrom(fd, buf, 0) if err != nil { return "", fmt.Errorf("dns: recv: %w", err) } if n < 12 { return "", fmt.Errorf("dns: response too short") } anCount := int(buf[6])<<8 | int(buf[7]) if anCount == 0 { return "", fmt.Errorf("dns: no answers for %s", host) } // Skip question section. pos := 12 for pos < n { if buf[pos] == 0 { pos++ break } if buf[pos]&0xC0 == 0xC0 { pos += 2 break } pos += int(buf[pos]) + 1 } pos += 4 // QTYPE + QCLASS // Find first A record. for a := 0; a < anCount && pos+10 < n; a++ { if buf[pos]&0xC0 == 0xC0 { pos += 2 } else { for pos < n && buf[pos] != 0 { pos += int(buf[pos]) + 1 } pos++ } if pos+10 > n { break } rtype := int(buf[pos])<<8 | int(buf[pos+1]) rdlen := int(buf[pos+8])<<8 | int(buf[pos+9]) pos += 10 if rtype == 1 && rdlen == 4 && pos+4 <= n { // Copy Sprintf result to avoid buffer reuse corruption. s := fmt.Sprintf("%d.%d.%d.%d", buf[pos], buf[pos+1], buf[pos+2], buf[pos+3]) ip := []byte{:len(s)} copy(ip, s) return string(ip), nil } pos += rdlen } return "", fmt.Errorf("dns: no A record for %s", host) }