tun.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package netstack
   7  
   8  import (
   9  	"bytes"
  10  	"context"
  11  	"crypto/rand"
  12  	"encoding/binary"
  13  	"errors"
  14  	"fmt"
  15  	"io"
  16  	"net"
  17  	"net/netip"
  18  	"os"
  19  	"regexp"
  20  	"strconv"
  21  	"strings"
  22  	"syscall"
  23  	"time"
  24  
  25  	"golang.zx2c4.com/wireguard/tun"
  26  
  27  	"golang.org/x/net/dns/dnsmessage"
  28  	"gvisor.dev/gvisor/pkg/buffer"
  29  	"gvisor.dev/gvisor/pkg/tcpip"
  30  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  31  	"gvisor.dev/gvisor/pkg/tcpip/header"
  32  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
  33  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  34  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  35  	"gvisor.dev/gvisor/pkg/tcpip/stack"
  36  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  37  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  38  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  39  	"gvisor.dev/gvisor/pkg/waiter"
  40  )
  41  
  42  type netTun struct {
  43  	ep             *channel.Endpoint
  44  	stack          *stack.Stack
  45  	events         chan tun.Event
  46  	notifyHandle   *channel.NotificationHandle
  47  	incomingPacket chan *buffer.View
  48  	mtu            int
  49  	dnsServers     []netip.Addr
  50  	hasV4, hasV6   bool
  51  }
  52  
  53  type Net netTun
  54  
  55  func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
  56  	opts := stack.Options{
  57  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  58  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
  59  		HandleLocal:        true,
  60  	}
  61  	dev := &netTun{
  62  		ep:             channel.New(1024, uint32(mtu), ""),
  63  		stack:          stack.New(opts),
  64  		events:         make(chan tun.Event, 10),
  65  		incomingPacket: make(chan *buffer.View),
  66  		dnsServers:     dnsServers,
  67  		mtu:            mtu,
  68  	}
  69  	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
  70  	tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
  71  	if tcpipErr != nil {
  72  		return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
  73  	}
  74  	dev.notifyHandle = dev.ep.AddNotify(dev)
  75  	tcpipErr = dev.stack.CreateNIC(1, dev.ep)
  76  	if tcpipErr != nil {
  77  		return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
  78  	}
  79  	for _, ip := range localAddresses {
  80  		var protoNumber tcpip.NetworkProtocolNumber
  81  		if ip.Is4() {
  82  			protoNumber = ipv4.ProtocolNumber
  83  		} else if ip.Is6() {
  84  			protoNumber = ipv6.ProtocolNumber
  85  		}
  86  		protoAddr := tcpip.ProtocolAddress{
  87  			Protocol:          protoNumber,
  88  			AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
  89  		}
  90  		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
  91  		if tcpipErr != nil {
  92  			return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
  93  		}
  94  		if ip.Is4() {
  95  			dev.hasV4 = true
  96  		} else if ip.Is6() {
  97  			dev.hasV6 = true
  98  		}
  99  	}
 100  	if dev.hasV4 {
 101  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
 102  	}
 103  	if dev.hasV6 {
 104  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
 105  	}
 106  
 107  	dev.events <- tun.EventUp
 108  	return dev, (*Net)(dev), nil
 109  }
 110  
 111  func (tun *netTun) Name() (string, error) {
 112  	return "go", nil
 113  }
 114  
 115  func (tun *netTun) File() *os.File {
 116  	return nil
 117  }
 118  
 119  func (tun *netTun) Events() <-chan tun.Event {
 120  	return tun.events
 121  }
 122  
 123  func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
 124  	view, ok := <-tun.incomingPacket
 125  	if !ok {
 126  		return 0, os.ErrClosed
 127  	}
 128  
 129  	n, err := view.Read(buf[0][offset:])
 130  	if err != nil {
 131  		return 0, err
 132  	}
 133  	sizes[0] = n
 134  	return 1, nil
 135  }
 136  
 137  func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
 138  	for _, buf := range buf {
 139  		packet := buf[offset:]
 140  		if len(packet) == 0 {
 141  			continue
 142  		}
 143  
 144  		pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
 145  		switch packet[0] >> 4 {
 146  		case 4:
 147  			tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
 148  		case 6:
 149  			tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
 150  		default:
 151  			return 0, syscall.EAFNOSUPPORT
 152  		}
 153  	}
 154  	return len(buf), nil
 155  }
 156  
 157  func (tun *netTun) WriteNotify() {
 158  	pkt := tun.ep.Read()
 159  	if pkt == nil {
 160  		return
 161  	}
 162  
 163  	view := pkt.ToView()
 164  	pkt.DecRef()
 165  
 166  	tun.incomingPacket <- view
 167  }
 168  
 169  func (tun *netTun) Close() error {
 170  	tun.stack.RemoveNIC(1)
 171  	tun.stack.Close()
 172  	tun.ep.RemoveNotify(tun.notifyHandle)
 173  	tun.ep.Close()
 174  
 175  	if tun.events != nil {
 176  		close(tun.events)
 177  	}
 178  
 179  	if tun.incomingPacket != nil {
 180  		close(tun.incomingPacket)
 181  	}
 182  
 183  	return nil
 184  }
 185  
 186  func (tun *netTun) MTU() (int, error) {
 187  	return tun.mtu, nil
 188  }
 189  
 190  func (tun *netTun) BatchSize() int {
 191  	return 1
 192  }
 193  
 194  func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
 195  	var protoNumber tcpip.NetworkProtocolNumber
 196  	if endpoint.Addr().Is4() {
 197  		protoNumber = ipv4.ProtocolNumber
 198  	} else {
 199  		protoNumber = ipv6.ProtocolNumber
 200  	}
 201  	return tcpip.FullAddress{
 202  		NIC:  1,
 203  		Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
 204  		Port: endpoint.Port(),
 205  	}, protoNumber
 206  }
 207  
 208  func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
 209  	fa, pn := convertToFullAddr(addr)
 210  	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
 211  }
 212  
 213  func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
 214  	if addr == nil {
 215  		return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
 216  	}
 217  	ip, _ := netip.AddrFromSlice(addr.IP)
 218  	return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
 219  }
 220  
 221  func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
 222  	fa, pn := convertToFullAddr(addr)
 223  	return gonet.DialTCP(net.stack, fa, pn)
 224  }
 225  
 226  func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
 227  	if addr == nil {
 228  		return net.DialTCPAddrPort(netip.AddrPort{})
 229  	}
 230  	ip, _ := netip.AddrFromSlice(addr.IP)
 231  	return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
 232  }
 233  
 234  func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
 235  	fa, pn := convertToFullAddr(addr)
 236  	return gonet.ListenTCP(net.stack, fa, pn)
 237  }
 238  
 239  func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
 240  	if addr == nil {
 241  		return net.ListenTCPAddrPort(netip.AddrPort{})
 242  	}
 243  	ip, _ := netip.AddrFromSlice(addr.IP)
 244  	return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
 245  }
 246  
 247  func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
 248  	var lfa, rfa *tcpip.FullAddress
 249  	var pn tcpip.NetworkProtocolNumber
 250  	if laddr.IsValid() || laddr.Port() > 0 {
 251  		var addr tcpip.FullAddress
 252  		addr, pn = convertToFullAddr(laddr)
 253  		lfa = &addr
 254  	}
 255  	if raddr.IsValid() || raddr.Port() > 0 {
 256  		var addr tcpip.FullAddress
 257  		addr, pn = convertToFullAddr(raddr)
 258  		rfa = &addr
 259  	}
 260  	return gonet.DialUDP(net.stack, lfa, rfa, pn)
 261  }
 262  
 263  func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
 264  	return net.DialUDPAddrPort(laddr, netip.AddrPort{})
 265  }
 266  
 267  func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
 268  	var la, ra netip.AddrPort
 269  	if laddr != nil {
 270  		ip, _ := netip.AddrFromSlice(laddr.IP)
 271  		la = netip.AddrPortFrom(ip, uint16(laddr.Port))
 272  	}
 273  	if raddr != nil {
 274  		ip, _ := netip.AddrFromSlice(raddr.IP)
 275  		ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
 276  	}
 277  	return net.DialUDPAddrPort(la, ra)
 278  }
 279  
 280  func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
 281  	return net.DialUDP(laddr, nil)
 282  }
 283  
 284  type PingConn struct {
 285  	laddr    PingAddr
 286  	raddr    PingAddr
 287  	wq       waiter.Queue
 288  	ep       tcpip.Endpoint
 289  	deadline *time.Timer
 290  }
 291  
 292  type PingAddr struct{ addr netip.Addr }
 293  
 294  func (ia PingAddr) String() string {
 295  	return ia.addr.String()
 296  }
 297  
 298  func (ia PingAddr) Network() string {
 299  	if ia.addr.Is4() {
 300  		return "ping4"
 301  	} else if ia.addr.Is6() {
 302  		return "ping6"
 303  	}
 304  	return "ping"
 305  }
 306  
 307  func (ia PingAddr) Addr() netip.Addr {
 308  	return ia.addr
 309  }
 310  
 311  func PingAddrFromAddr(addr netip.Addr) *PingAddr {
 312  	return &PingAddr{addr}
 313  }
 314  
 315  func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
 316  	if !laddr.IsValid() && !raddr.IsValid() {
 317  		return nil, errors.New("ping dial: invalid address")
 318  	}
 319  	v6 := laddr.Is6() || raddr.Is6()
 320  	bind := laddr.IsValid()
 321  	if !bind {
 322  		if v6 {
 323  			laddr = netip.IPv6Unspecified()
 324  		} else {
 325  			laddr = netip.IPv4Unspecified()
 326  		}
 327  	}
 328  
 329  	tn := icmp.ProtocolNumber4
 330  	pn := ipv4.ProtocolNumber
 331  	if v6 {
 332  		tn = icmp.ProtocolNumber6
 333  		pn = ipv6.ProtocolNumber
 334  	}
 335  
 336  	pc := &PingConn{
 337  		laddr:    PingAddr{laddr},
 338  		deadline: time.NewTimer(time.Hour << 10),
 339  	}
 340  	pc.deadline.Stop()
 341  
 342  	ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
 343  	if tcpipErr != nil {
 344  		return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
 345  	}
 346  	pc.ep = ep
 347  
 348  	if bind {
 349  		fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
 350  		if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
 351  			return nil, fmt.Errorf("ping bind: %s", tcpipErr)
 352  		}
 353  	}
 354  
 355  	if raddr.IsValid() {
 356  		pc.raddr = PingAddr{raddr}
 357  		fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
 358  		if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
 359  			return nil, fmt.Errorf("ping connect: %s", tcpipErr)
 360  		}
 361  	}
 362  
 363  	return pc, nil
 364  }
 365  
 366  func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
 367  	return net.DialPingAddr(laddr, netip.Addr{})
 368  }
 369  
 370  func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
 371  	var la, ra netip.Addr
 372  	if laddr != nil {
 373  		la = laddr.addr
 374  	}
 375  	if raddr != nil {
 376  		ra = raddr.addr
 377  	}
 378  	return net.DialPingAddr(la, ra)
 379  }
 380  
 381  func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
 382  	var la netip.Addr
 383  	if laddr != nil {
 384  		la = laddr.addr
 385  	}
 386  	return net.ListenPingAddr(la)
 387  }
 388  
 389  func (pc *PingConn) LocalAddr() net.Addr {
 390  	return pc.laddr
 391  }
 392  
 393  func (pc *PingConn) RemoteAddr() net.Addr {
 394  	return pc.raddr
 395  }
 396  
 397  func (pc *PingConn) Close() error {
 398  	pc.deadline.Reset(0)
 399  	pc.ep.Close()
 400  	return nil
 401  }
 402  
 403  func (pc *PingConn) SetWriteDeadline(t time.Time) error {
 404  	return errors.New("not implemented")
 405  }
 406  
 407  func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
 408  	var na netip.Addr
 409  	switch v := addr.(type) {
 410  	case *PingAddr:
 411  		na = v.addr
 412  	case *net.IPAddr:
 413  		na, _ = netip.AddrFromSlice(v.IP)
 414  	default:
 415  		return 0, fmt.Errorf("ping write: wrong net.Addr type")
 416  	}
 417  	if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
 418  		return 0, fmt.Errorf("ping write: mismatched protocols")
 419  	}
 420  
 421  	buf := bytes.NewReader(p)
 422  	rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
 423  	// won't block, no deadlines
 424  	n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
 425  		To: &rfa,
 426  	})
 427  	if tcpipErr != nil {
 428  		return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
 429  	}
 430  
 431  	return int(n64), nil
 432  }
 433  
 434  func (pc *PingConn) Write(p []byte) (n int, err error) {
 435  	return pc.WriteTo(p, &pc.raddr)
 436  }
 437  
 438  func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
 439  	e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
 440  	pc.wq.EventRegister(&e)
 441  	defer pc.wq.EventUnregister(&e)
 442  
 443  	select {
 444  	case <-pc.deadline.C:
 445  		return 0, nil, os.ErrDeadlineExceeded
 446  	case <-notifyCh:
 447  	}
 448  
 449  	w := tcpip.SliceWriter(p)
 450  
 451  	res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
 452  		NeedRemoteAddr: true,
 453  	})
 454  	if tcpipErr != nil {
 455  		return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
 456  	}
 457  
 458  	remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
 459  	return res.Count, &PingAddr{remoteAddr}, nil
 460  }
 461  
 462  func (pc *PingConn) Read(p []byte) (n int, err error) {
 463  	n, _, err = pc.ReadFrom(p)
 464  	return
 465  }
 466  
 467  func (pc *PingConn) SetDeadline(t time.Time) error {
 468  	// pc.SetWriteDeadline is unimplemented
 469  
 470  	return pc.SetReadDeadline(t)
 471  }
 472  
 473  func (pc *PingConn) SetReadDeadline(t time.Time) error {
 474  	pc.deadline.Reset(time.Until(t))
 475  	return nil
 476  }
 477  
 478  var (
 479  	errNoSuchHost                   = errors.New("no such host")
 480  	errLameReferral                 = errors.New("lame referral")
 481  	errCannotUnmarshalDNSMessage    = errors.New("cannot unmarshal DNS message")
 482  	errCannotMarshalDNSMessage      = errors.New("cannot marshal DNS message")
 483  	errServerMisbehaving            = errors.New("server misbehaving")
 484  	errInvalidDNSResponse           = errors.New("invalid DNS response")
 485  	errNoAnswerFromDNSServer        = errors.New("no answer from DNS server")
 486  	errServerTemporarilyMisbehaving = errors.New("server misbehaving")
 487  	errCanceled                     = errors.New("operation was canceled")
 488  	errTimeout                      = errors.New("i/o timeout")
 489  	errNumericPort                  = errors.New("port must be numeric")
 490  	errNoSuitableAddress            = errors.New("no suitable address found")
 491  	errMissingAddress               = errors.New("missing address")
 492  )
 493  
 494  func (net *Net) LookupHost(host string) (addrs []string, err error) {
 495  	return net.LookupContextHost(context.Background(), host)
 496  }
 497  
 498  func isDomainName(s string) bool {
 499  	l := len(s)
 500  	if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
 501  		return false
 502  	}
 503  	last := byte('.')
 504  	nonNumeric := false
 505  	partlen := 0
 506  	for i := 0; i < len(s); i++ {
 507  		c := s[i]
 508  		switch {
 509  		default:
 510  			return false
 511  		case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
 512  			nonNumeric = true
 513  			partlen++
 514  		case '0' <= c && c <= '9':
 515  			partlen++
 516  		case c == '-':
 517  			if last == '.' {
 518  				return false
 519  			}
 520  			partlen++
 521  			nonNumeric = true
 522  		case c == '.':
 523  			if last == '.' || last == '-' {
 524  				return false
 525  			}
 526  			if partlen > 63 || partlen == 0 {
 527  				return false
 528  			}
 529  			partlen = 0
 530  		}
 531  		last = c
 532  	}
 533  	if last == '-' || partlen > 63 {
 534  		return false
 535  	}
 536  	return nonNumeric
 537  }
 538  
 539  func randU16() uint16 {
 540  	var b [2]byte
 541  	_, err := rand.Read(b[:])
 542  	if err != nil {
 543  		panic(err)
 544  	}
 545  	return binary.LittleEndian.Uint16(b[:])
 546  }
 547  
 548  func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
 549  	id = randU16()
 550  	b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
 551  	b.EnableCompression()
 552  	if err := b.StartQuestions(); err != nil {
 553  		return 0, nil, nil, err
 554  	}
 555  	if err := b.Question(q); err != nil {
 556  		return 0, nil, nil, err
 557  	}
 558  	tcpReq, err = b.Finish()
 559  	udpReq = tcpReq[2:]
 560  	l := len(tcpReq) - 2
 561  	tcpReq[0] = byte(l >> 8)
 562  	tcpReq[1] = byte(l)
 563  	return id, udpReq, tcpReq, err
 564  }
 565  
 566  func equalASCIIName(x, y dnsmessage.Name) bool {
 567  	if x.Length != y.Length {
 568  		return false
 569  	}
 570  	for i := 0; i < int(x.Length); i++ {
 571  		a := x.Data[i]
 572  		b := y.Data[i]
 573  		if 'A' <= a && a <= 'Z' {
 574  			a += 0x20
 575  		}
 576  		if 'A' <= b && b <= 'Z' {
 577  			b += 0x20
 578  		}
 579  		if a != b {
 580  			return false
 581  		}
 582  	}
 583  	return true
 584  }
 585  
 586  func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
 587  	if !respHdr.Response {
 588  		return false
 589  	}
 590  	if reqID != respHdr.ID {
 591  		return false
 592  	}
 593  	if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
 594  		return false
 595  	}
 596  	return true
 597  }
 598  
 599  func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
 600  	if _, err := c.Write(b); err != nil {
 601  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
 602  	}
 603  	b = make([]byte, 512)
 604  	for {
 605  		n, err := c.Read(b)
 606  		if err != nil {
 607  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
 608  		}
 609  		var p dnsmessage.Parser
 610  		h, err := p.Start(b[:n])
 611  		if err != nil {
 612  			continue
 613  		}
 614  		q, err := p.Question()
 615  		if err != nil || !checkResponse(id, query, h, q) {
 616  			continue
 617  		}
 618  		return p, h, nil
 619  	}
 620  }
 621  
 622  func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
 623  	if _, err := c.Write(b); err != nil {
 624  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
 625  	}
 626  	b = make([]byte, 1280)
 627  	if _, err := io.ReadFull(c, b[:2]); err != nil {
 628  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
 629  	}
 630  	l := int(b[0])<<8 | int(b[1])
 631  	if l > len(b) {
 632  		b = make([]byte, l)
 633  	}
 634  	n, err := io.ReadFull(c, b[:l])
 635  	if err != nil {
 636  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
 637  	}
 638  	var p dnsmessage.Parser
 639  	h, err := p.Start(b[:n])
 640  	if err != nil {
 641  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
 642  	}
 643  	q, err := p.Question()
 644  	if err != nil {
 645  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
 646  	}
 647  	if !checkResponse(id, query, h, q) {
 648  		return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
 649  	}
 650  	return p, h, nil
 651  }
 652  
 653  func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
 654  	q.Class = dnsmessage.ClassINET
 655  	id, udpReq, tcpReq, err := newRequest(q)
 656  	if err != nil {
 657  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
 658  	}
 659  
 660  	for _, useUDP := range []bool{true, false} {
 661  		ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
 662  		defer cancel()
 663  
 664  		var c net.Conn
 665  		var err error
 666  		if useUDP {
 667  			c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
 668  		} else {
 669  			c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
 670  		}
 671  
 672  		if err != nil {
 673  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
 674  		}
 675  		if d, ok := ctx.Deadline(); ok && !d.IsZero() {
 676  			err := c.SetDeadline(d)
 677  			if err != nil {
 678  				return dnsmessage.Parser{}, dnsmessage.Header{}, err
 679  			}
 680  		}
 681  		var p dnsmessage.Parser
 682  		var h dnsmessage.Header
 683  		if useUDP {
 684  			p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
 685  		} else {
 686  			p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
 687  		}
 688  		c.Close()
 689  		if err != nil {
 690  			if err == context.Canceled {
 691  				err = errCanceled
 692  			} else if err == context.DeadlineExceeded {
 693  				err = errTimeout
 694  			}
 695  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
 696  		}
 697  		if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
 698  			return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
 699  		}
 700  		if h.Truncated {
 701  			continue
 702  		}
 703  		return p, h, nil
 704  	}
 705  	return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
 706  }
 707  
 708  func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
 709  	if h.RCode == dnsmessage.RCodeNameError {
 710  		return errNoSuchHost
 711  	}
 712  	_, err := p.AnswerHeader()
 713  	if err != nil && err != dnsmessage.ErrSectionDone {
 714  		return errCannotUnmarshalDNSMessage
 715  	}
 716  	if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
 717  		return errLameReferral
 718  	}
 719  	if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
 720  		if h.RCode == dnsmessage.RCodeServerFailure {
 721  			return errServerTemporarilyMisbehaving
 722  		}
 723  		return errServerMisbehaving
 724  	}
 725  	return nil
 726  }
 727  
 728  func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
 729  	for {
 730  		h, err := p.AnswerHeader()
 731  		if err == dnsmessage.ErrSectionDone {
 732  			return errNoSuchHost
 733  		}
 734  		if err != nil {
 735  			return errCannotUnmarshalDNSMessage
 736  		}
 737  		if h.Type == qtype {
 738  			return nil
 739  		}
 740  		if err := p.SkipAnswer(); err != nil {
 741  			return errCannotUnmarshalDNSMessage
 742  		}
 743  	}
 744  }
 745  
 746  func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
 747  	var lastErr error
 748  
 749  	n, err := dnsmessage.NewName(name)
 750  	if err != nil {
 751  		return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
 752  	}
 753  	q := dnsmessage.Question{
 754  		Name:  n,
 755  		Type:  qtype,
 756  		Class: dnsmessage.ClassINET,
 757  	}
 758  
 759  	for i := 0; i < 2; i++ {
 760  		for _, server := range tnet.dnsServers {
 761  			p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
 762  			if err != nil {
 763  				dnsErr := &net.DNSError{
 764  					Err:    err.Error(),
 765  					Name:   name,
 766  					Server: server.String(),
 767  				}
 768  				if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
 769  					dnsErr.IsTimeout = true
 770  				}
 771  				if _, ok := err.(*net.OpError); ok {
 772  					dnsErr.IsTemporary = true
 773  				}
 774  				lastErr = dnsErr
 775  				continue
 776  			}
 777  
 778  			if err := checkHeader(&p, h); err != nil {
 779  				dnsErr := &net.DNSError{
 780  					Err:    err.Error(),
 781  					Name:   name,
 782  					Server: server.String(),
 783  				}
 784  				if err == errServerTemporarilyMisbehaving {
 785  					dnsErr.IsTemporary = true
 786  				}
 787  				if err == errNoSuchHost {
 788  					dnsErr.IsNotFound = true
 789  					return p, server.String(), dnsErr
 790  				}
 791  				lastErr = dnsErr
 792  				continue
 793  			}
 794  
 795  			err = skipToAnswer(&p, qtype)
 796  			if err == nil {
 797  				return p, server.String(), nil
 798  			}
 799  			lastErr = &net.DNSError{
 800  				Err:    err.Error(),
 801  				Name:   name,
 802  				Server: server.String(),
 803  			}
 804  			if err == errNoSuchHost {
 805  				lastErr.(*net.DNSError).IsNotFound = true
 806  				return p, server.String(), lastErr
 807  			}
 808  		}
 809  	}
 810  	return dnsmessage.Parser{}, "", lastErr
 811  }
 812  
 813  func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
 814  	if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
 815  		return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
 816  	}
 817  	zlen := len(host)
 818  	if strings.IndexByte(host, ':') != -1 {
 819  		if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
 820  			zlen = zidx
 821  		}
 822  	}
 823  	if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
 824  		return []string{ip.String()}, nil
 825  	}
 826  
 827  	if !isDomainName(host) {
 828  		return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
 829  	}
 830  	type result struct {
 831  		p      dnsmessage.Parser
 832  		server string
 833  		error
 834  	}
 835  	var addrsV4, addrsV6 []netip.Addr
 836  	lanes := 0
 837  	if tnet.hasV4 {
 838  		lanes++
 839  	}
 840  	if tnet.hasV6 {
 841  		lanes++
 842  	}
 843  	lane := make(chan result, lanes)
 844  	var lastErr error
 845  	if tnet.hasV4 {
 846  		go func() {
 847  			p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
 848  			lane <- result{p, server, err}
 849  		}()
 850  	}
 851  	if tnet.hasV6 {
 852  		go func() {
 853  			p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
 854  			lane <- result{p, server, err}
 855  		}()
 856  	}
 857  	for l := 0; l < lanes; l++ {
 858  		result := <-lane
 859  		if result.error != nil {
 860  			if lastErr == nil {
 861  				lastErr = result.error
 862  			}
 863  			continue
 864  		}
 865  
 866  	loop:
 867  		for {
 868  			h, err := result.p.AnswerHeader()
 869  			if err != nil && err != dnsmessage.ErrSectionDone {
 870  				lastErr = &net.DNSError{
 871  					Err:    errCannotMarshalDNSMessage.Error(),
 872  					Name:   host,
 873  					Server: result.server,
 874  				}
 875  			}
 876  			if err != nil {
 877  				break
 878  			}
 879  			switch h.Type {
 880  			case dnsmessage.TypeA:
 881  				a, err := result.p.AResource()
 882  				if err != nil {
 883  					lastErr = &net.DNSError{
 884  						Err:    errCannotMarshalDNSMessage.Error(),
 885  						Name:   host,
 886  						Server: result.server,
 887  					}
 888  					break loop
 889  				}
 890  				addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
 891  
 892  			case dnsmessage.TypeAAAA:
 893  				aaaa, err := result.p.AAAAResource()
 894  				if err != nil {
 895  					lastErr = &net.DNSError{
 896  						Err:    errCannotMarshalDNSMessage.Error(),
 897  						Name:   host,
 898  						Server: result.server,
 899  					}
 900  					break loop
 901  				}
 902  				addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
 903  
 904  			default:
 905  				if err := result.p.SkipAnswer(); err != nil {
 906  					lastErr = &net.DNSError{
 907  						Err:    errCannotMarshalDNSMessage.Error(),
 908  						Name:   host,
 909  						Server: result.server,
 910  					}
 911  					break loop
 912  				}
 913  				continue
 914  			}
 915  		}
 916  	}
 917  	// We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
 918  	var addrs []netip.Addr
 919  	if tnet.hasV6 {
 920  		addrs = append(addrsV6, addrsV4...)
 921  	} else {
 922  		addrs = append(addrsV4, addrsV6...)
 923  	}
 924  
 925  	if len(addrs) == 0 && lastErr != nil {
 926  		return nil, lastErr
 927  	}
 928  	saddrs := make([]string, 0, len(addrs))
 929  	for _, ip := range addrs {
 930  		saddrs = append(saddrs, ip.String())
 931  	}
 932  	return saddrs, nil
 933  }
 934  
 935  func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
 936  	if deadline.IsZero() {
 937  		return deadline, nil
 938  	}
 939  	timeRemaining := deadline.Sub(now)
 940  	if timeRemaining <= 0 {
 941  		return time.Time{}, errTimeout
 942  	}
 943  	timeout := timeRemaining / time.Duration(addrsRemaining)
 944  	const saneMinimum = 2 * time.Second
 945  	if timeout < saneMinimum {
 946  		if timeRemaining < saneMinimum {
 947  			timeout = timeRemaining
 948  		} else {
 949  			timeout = saneMinimum
 950  		}
 951  	}
 952  	return now.Add(timeout), nil
 953  }
 954  
 955  var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
 956  
 957  func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
 958  	if ctx == nil {
 959  		panic("nil context")
 960  	}
 961  	var acceptV4, acceptV6 bool
 962  	matches := protoSplitter.FindStringSubmatch(network)
 963  	if matches == nil {
 964  		return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
 965  	} else if len(matches[2]) == 0 {
 966  		acceptV4 = true
 967  		acceptV6 = true
 968  	} else {
 969  		acceptV4 = matches[2][0] == '4'
 970  		acceptV6 = !acceptV4
 971  	}
 972  	var host string
 973  	var port int
 974  	if matches[1] == "ping" {
 975  		host = address
 976  	} else {
 977  		var sport string
 978  		var err error
 979  		host, sport, err = net.SplitHostPort(address)
 980  		if err != nil {
 981  			return nil, &net.OpError{Op: "dial", Err: err}
 982  		}
 983  		port, err = strconv.Atoi(sport)
 984  		if err != nil || port < 0 || port > 65535 {
 985  			return nil, &net.OpError{Op: "dial", Err: errNumericPort}
 986  		}
 987  	}
 988  	allAddr, err := tnet.LookupContextHost(ctx, host)
 989  	if err != nil {
 990  		return nil, &net.OpError{Op: "dial", Err: err}
 991  	}
 992  	var addrs []netip.AddrPort
 993  	for _, addr := range allAddr {
 994  		ip, err := netip.ParseAddr(addr)
 995  		if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
 996  			addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
 997  		}
 998  	}
 999  	if len(addrs) == 0 && len(allAddr) != 0 {
1000  		return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
1001  	}
1002  
1003  	var firstErr error
1004  	for i, addr := range addrs {
1005  		select {
1006  		case <-ctx.Done():
1007  			err := ctx.Err()
1008  			if err == context.Canceled {
1009  				err = errCanceled
1010  			} else if err == context.DeadlineExceeded {
1011  				err = errTimeout
1012  			}
1013  			return nil, &net.OpError{Op: "dial", Err: err}
1014  		default:
1015  		}
1016  
1017  		dialCtx := ctx
1018  		if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
1019  			partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
1020  			if err != nil {
1021  				if firstErr == nil {
1022  					firstErr = &net.OpError{Op: "dial", Err: err}
1023  				}
1024  				break
1025  			}
1026  			if partialDeadline.Before(deadline) {
1027  				var cancel context.CancelFunc
1028  				dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
1029  				defer cancel()
1030  			}
1031  		}
1032  
1033  		var c net.Conn
1034  		switch matches[1] {
1035  		case "tcp":
1036  			c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
1037  		case "udp":
1038  			c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
1039  		case "ping":
1040  			c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
1041  		}
1042  		if err == nil {
1043  			return c, nil
1044  		}
1045  		if firstErr == nil {
1046  			firstErr = err
1047  		}
1048  	}
1049  	if firstErr == nil {
1050  		firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
1051  	}
1052  	return nil, firstErr
1053  }
1054  
1055  func (tnet *Net) Dial(network, address string) (net.Conn, error) {
1056  	return tnet.DialContext(context.Background(), network, address)
1057  }
1058