tun.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package netstack
7
8 import (
9 "bytes"
10 "context"
11 "crypto/rand"
12 "encoding/binary"
13 "errors"
14 "fmt"
15 "io"
16 "net"
17 "net/netip"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "syscall"
23 "time"
24
25 "golang.zx2c4.com/wireguard/tun"
26
27 "golang.org/x/net/dns/dnsmessage"
28 "gvisor.dev/gvisor/pkg/buffer"
29 "gvisor.dev/gvisor/pkg/tcpip"
30 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
31 "gvisor.dev/gvisor/pkg/tcpip/header"
32 "gvisor.dev/gvisor/pkg/tcpip/link/channel"
33 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
34 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
35 "gvisor.dev/gvisor/pkg/tcpip/stack"
36 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
37 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
38 "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
39 "gvisor.dev/gvisor/pkg/waiter"
40 )
41
42 type netTun struct {
43 ep *channel.Endpoint
44 stack *stack.Stack
45 events chan tun.Event
46 notifyHandle *channel.NotificationHandle
47 incomingPacket chan *buffer.View
48 mtu int
49 dnsServers []netip.Addr
50 hasV4, hasV6 bool
51 }
52
53 type Net netTun
54
55 func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
56 opts := stack.Options{
57 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
58 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
59 HandleLocal: true,
60 }
61 dev := &netTun{
62 ep: channel.New(1024, uint32(mtu), ""),
63 stack: stack.New(opts),
64 events: make(chan tun.Event, 10),
65 incomingPacket: make(chan *buffer.View),
66 dnsServers: dnsServers,
67 mtu: mtu,
68 }
69 sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
70 tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
71 if tcpipErr != nil {
72 return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
73 }
74 dev.notifyHandle = dev.ep.AddNotify(dev)
75 tcpipErr = dev.stack.CreateNIC(1, dev.ep)
76 if tcpipErr != nil {
77 return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
78 }
79 for _, ip := range localAddresses {
80 var protoNumber tcpip.NetworkProtocolNumber
81 if ip.Is4() {
82 protoNumber = ipv4.ProtocolNumber
83 } else if ip.Is6() {
84 protoNumber = ipv6.ProtocolNumber
85 }
86 protoAddr := tcpip.ProtocolAddress{
87 Protocol: protoNumber,
88 AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
89 }
90 tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
91 if tcpipErr != nil {
92 return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
93 }
94 if ip.Is4() {
95 dev.hasV4 = true
96 } else if ip.Is6() {
97 dev.hasV6 = true
98 }
99 }
100 if dev.hasV4 {
101 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
102 }
103 if dev.hasV6 {
104 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
105 }
106
107 dev.events <- tun.EventUp
108 return dev, (*Net)(dev), nil
109 }
110
111 func (tun *netTun) Name() (string, error) {
112 return "go", nil
113 }
114
115 func (tun *netTun) File() *os.File {
116 return nil
117 }
118
119 func (tun *netTun) Events() <-chan tun.Event {
120 return tun.events
121 }
122
123 func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
124 view, ok := <-tun.incomingPacket
125 if !ok {
126 return 0, os.ErrClosed
127 }
128
129 n, err := view.Read(buf[0][offset:])
130 if err != nil {
131 return 0, err
132 }
133 sizes[0] = n
134 return 1, nil
135 }
136
137 func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
138 for _, buf := range buf {
139 packet := buf[offset:]
140 if len(packet) == 0 {
141 continue
142 }
143
144 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
145 switch packet[0] >> 4 {
146 case 4:
147 tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
148 case 6:
149 tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
150 default:
151 return 0, syscall.EAFNOSUPPORT
152 }
153 }
154 return len(buf), nil
155 }
156
157 func (tun *netTun) WriteNotify() {
158 pkt := tun.ep.Read()
159 if pkt == nil {
160 return
161 }
162
163 view := pkt.ToView()
164 pkt.DecRef()
165
166 tun.incomingPacket <- view
167 }
168
169 func (tun *netTun) Close() error {
170 tun.stack.RemoveNIC(1)
171 tun.stack.Close()
172 tun.ep.RemoveNotify(tun.notifyHandle)
173 tun.ep.Close()
174
175 if tun.events != nil {
176 close(tun.events)
177 }
178
179 if tun.incomingPacket != nil {
180 close(tun.incomingPacket)
181 }
182
183 return nil
184 }
185
186 func (tun *netTun) MTU() (int, error) {
187 return tun.mtu, nil
188 }
189
190 func (tun *netTun) BatchSize() int {
191 return 1
192 }
193
194 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
195 var protoNumber tcpip.NetworkProtocolNumber
196 if endpoint.Addr().Is4() {
197 protoNumber = ipv4.ProtocolNumber
198 } else {
199 protoNumber = ipv6.ProtocolNumber
200 }
201 return tcpip.FullAddress{
202 NIC: 1,
203 Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
204 Port: endpoint.Port(),
205 }, protoNumber
206 }
207
208 func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
209 fa, pn := convertToFullAddr(addr)
210 return gonet.DialContextTCP(ctx, net.stack, fa, pn)
211 }
212
213 func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
214 if addr == nil {
215 return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
216 }
217 ip, _ := netip.AddrFromSlice(addr.IP)
218 return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
219 }
220
221 func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
222 fa, pn := convertToFullAddr(addr)
223 return gonet.DialTCP(net.stack, fa, pn)
224 }
225
226 func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
227 if addr == nil {
228 return net.DialTCPAddrPort(netip.AddrPort{})
229 }
230 ip, _ := netip.AddrFromSlice(addr.IP)
231 return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
232 }
233
234 func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
235 fa, pn := convertToFullAddr(addr)
236 return gonet.ListenTCP(net.stack, fa, pn)
237 }
238
239 func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
240 if addr == nil {
241 return net.ListenTCPAddrPort(netip.AddrPort{})
242 }
243 ip, _ := netip.AddrFromSlice(addr.IP)
244 return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
245 }
246
247 func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
248 var lfa, rfa *tcpip.FullAddress
249 var pn tcpip.NetworkProtocolNumber
250 if laddr.IsValid() || laddr.Port() > 0 {
251 var addr tcpip.FullAddress
252 addr, pn = convertToFullAddr(laddr)
253 lfa = &addr
254 }
255 if raddr.IsValid() || raddr.Port() > 0 {
256 var addr tcpip.FullAddress
257 addr, pn = convertToFullAddr(raddr)
258 rfa = &addr
259 }
260 return gonet.DialUDP(net.stack, lfa, rfa, pn)
261 }
262
263 func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
264 return net.DialUDPAddrPort(laddr, netip.AddrPort{})
265 }
266
267 func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
268 var la, ra netip.AddrPort
269 if laddr != nil {
270 ip, _ := netip.AddrFromSlice(laddr.IP)
271 la = netip.AddrPortFrom(ip, uint16(laddr.Port))
272 }
273 if raddr != nil {
274 ip, _ := netip.AddrFromSlice(raddr.IP)
275 ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
276 }
277 return net.DialUDPAddrPort(la, ra)
278 }
279
280 func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
281 return net.DialUDP(laddr, nil)
282 }
283
284 type PingConn struct {
285 laddr PingAddr
286 raddr PingAddr
287 wq waiter.Queue
288 ep tcpip.Endpoint
289 deadline *time.Timer
290 }
291
// PingAddr is the net.Addr used by PingConn: a bare IP address, no port.
type PingAddr struct{ addr netip.Addr }

// String returns the textual form of the wrapped IP address.
func (pa PingAddr) String() string {
	return pa.addr.String()
}

// Network returns "ping4" or "ping6" depending on the address family. It
// returns "ping" when the address is neither, e.g. the zero Addr.
func (pa PingAddr) Network() string {
	switch {
	case pa.addr.Is4():
		return "ping4"
	case pa.addr.Is6():
		return "ping6"
	default:
		return "ping"
	}
}

// Addr returns the wrapped IP address.
func (pa PingAddr) Addr() netip.Addr { return pa.addr }

// PingAddrFromAddr wraps addr in a newly allocated PingAddr.
func PingAddrFromAddr(addr netip.Addr) *PingAddr {
	pa := PingAddr{addr: addr}
	return &pa
}
314
315 func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
316 if !laddr.IsValid() && !raddr.IsValid() {
317 return nil, errors.New("ping dial: invalid address")
318 }
319 v6 := laddr.Is6() || raddr.Is6()
320 bind := laddr.IsValid()
321 if !bind {
322 if v6 {
323 laddr = netip.IPv6Unspecified()
324 } else {
325 laddr = netip.IPv4Unspecified()
326 }
327 }
328
329 tn := icmp.ProtocolNumber4
330 pn := ipv4.ProtocolNumber
331 if v6 {
332 tn = icmp.ProtocolNumber6
333 pn = ipv6.ProtocolNumber
334 }
335
336 pc := &PingConn{
337 laddr: PingAddr{laddr},
338 deadline: time.NewTimer(time.Hour << 10),
339 }
340 pc.deadline.Stop()
341
342 ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
343 if tcpipErr != nil {
344 return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
345 }
346 pc.ep = ep
347
348 if bind {
349 fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
350 if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
351 return nil, fmt.Errorf("ping bind: %s", tcpipErr)
352 }
353 }
354
355 if raddr.IsValid() {
356 pc.raddr = PingAddr{raddr}
357 fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
358 if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
359 return nil, fmt.Errorf("ping connect: %s", tcpipErr)
360 }
361 }
362
363 return pc, nil
364 }
365
366 func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
367 return net.DialPingAddr(laddr, netip.Addr{})
368 }
369
370 func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
371 var la, ra netip.Addr
372 if laddr != nil {
373 la = laddr.addr
374 }
375 if raddr != nil {
376 ra = raddr.addr
377 }
378 return net.DialPingAddr(la, ra)
379 }
380
381 func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
382 var la netip.Addr
383 if laddr != nil {
384 la = laddr.addr
385 }
386 return net.ListenPingAddr(la)
387 }
388
389 func (pc *PingConn) LocalAddr() net.Addr {
390 return pc.laddr
391 }
392
393 func (pc *PingConn) RemoteAddr() net.Addr {
394 return pc.raddr
395 }
396
397 func (pc *PingConn) Close() error {
398 pc.deadline.Reset(0)
399 pc.ep.Close()
400 return nil
401 }
402
403 func (pc *PingConn) SetWriteDeadline(t time.Time) error {
404 return errors.New("not implemented")
405 }
406
407 func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
408 var na netip.Addr
409 switch v := addr.(type) {
410 case *PingAddr:
411 na = v.addr
412 case *net.IPAddr:
413 na, _ = netip.AddrFromSlice(v.IP)
414 default:
415 return 0, fmt.Errorf("ping write: wrong net.Addr type")
416 }
417 if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
418 return 0, fmt.Errorf("ping write: mismatched protocols")
419 }
420
421 buf := bytes.NewReader(p)
422 rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
423 // won't block, no deadlines
424 n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
425 To: &rfa,
426 })
427 if tcpipErr != nil {
428 return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
429 }
430
431 return int(n64), nil
432 }
433
434 func (pc *PingConn) Write(p []byte) (n int, err error) {
435 return pc.WriteTo(p, &pc.raddr)
436 }
437
438 func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
439 e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
440 pc.wq.EventRegister(&e)
441 defer pc.wq.EventUnregister(&e)
442
443 select {
444 case <-pc.deadline.C:
445 return 0, nil, os.ErrDeadlineExceeded
446 case <-notifyCh:
447 }
448
449 w := tcpip.SliceWriter(p)
450
451 res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
452 NeedRemoteAddr: true,
453 })
454 if tcpipErr != nil {
455 return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
456 }
457
458 remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
459 return res.Count, &PingAddr{remoteAddr}, nil
460 }
461
462 func (pc *PingConn) Read(p []byte) (n int, err error) {
463 n, _, err = pc.ReadFrom(p)
464 return
465 }
466
467 func (pc *PingConn) SetDeadline(t time.Time) error {
468 // pc.SetWriteDeadline is unimplemented
469
470 return pc.SetReadDeadline(t)
471 }
472
473 func (pc *PingConn) SetReadDeadline(t time.Time) error {
474 pc.deadline.Reset(time.Until(t))
475 return nil
476 }
477
// Errors reported by the DNS resolver.
var (
	errNoSuchHost                = errors.New("no such host")
	errLameReferral              = errors.New("lame referral")
	errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
	errCannotMarshalDNSMessage   = errors.New("cannot marshal DNS message")
	errServerMisbehaving         = errors.New("server misbehaving")
	errInvalidDNSResponse        = errors.New("invalid DNS response")
	errNoAnswerFromDNSServer     = errors.New("no answer from DNS server")

	// errServerTemporarilyMisbehaving has the same text as
	// errServerMisbehaving but is a distinct value, so a SERVFAIL reply can
	// be flagged as temporary.
	errServerTemporarilyMisbehaving = errors.New("server misbehaving")
)

// Errors for cancellation, timeouts and dial-address handling.
var (
	errCanceled          = errors.New("operation was canceled")
	errTimeout           = errors.New("i/o timeout")
	errNumericPort       = errors.New("port must be numeric")
	errNoSuitableAddress = errors.New("no suitable address found")
	errMissingAddress    = errors.New("missing address")
)
493
494 func (net *Net) LookupHost(host string) (addrs []string, err error) {
495 return net.LookupContextHost(context.Background(), host)
496 }
497
// isDomainName reports whether s is a syntactically acceptable DNS name.
//
// The rules are:
//   - s is at most 253 bytes, or 254 bytes with a trailing dot;
//   - an optional single trailing dot is allowed;
//   - the remaining labels are separated by dots;
//   - each label is 1-63 bytes of ASCII letters, digits, '_' or '-';
//   - no label starts or ends with '-';
//   - at least one byte of the name is not a digit, so purely numeric
//     input such as "1.2.3.4" is rejected.
func isDomainName(s string) bool {
	switch n := len(s); {
	case n == 0, n > 254, n == 254 && s[n-1] != '.':
		return false
	}

	sawNonDigit := false
	for _, label := range strings.Split(strings.TrimSuffix(s, "."), ".") {
		if len(label) == 0 || len(label) > 63 {
			return false
		}
		if label[0] == '-' || label[len(label)-1] == '-' {
			return false
		}
		for j := 0; j < len(label); j++ {
			ch := label[j]
			isDigit := '0' <= ch && ch <= '9'
			isAlpha := 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z'
			if !isDigit && !isAlpha && ch != '_' && ch != '-' {
				return false
			}
			if !isDigit {
				sawNonDigit = true
			}
		}
	}
	return sawNonDigit
}
538
// randU16 returns a 16-bit value read from crypto/rand; newRequest uses it
// as the DNS query ID. It panics if the system random source fails.
func randU16() uint16 {
	b := make([]byte, 2)
	if _, err := rand.Read(b); err != nil {
		panic(err)
	}
	return binary.LittleEndian.Uint16(b)
}
547
548 func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
549 id = randU16()
550 b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
551 b.EnableCompression()
552 if err := b.StartQuestions(); err != nil {
553 return 0, nil, nil, err
554 }
555 if err := b.Question(q); err != nil {
556 return 0, nil, nil, err
557 }
558 tcpReq, err = b.Finish()
559 udpReq = tcpReq[2:]
560 l := len(tcpReq) - 2
561 tcpReq[0] = byte(l >> 8)
562 tcpReq[1] = byte(l)
563 return id, udpReq, tcpReq, err
564 }
565
566 func equalASCIIName(x, y dnsmessage.Name) bool {
567 if x.Length != y.Length {
568 return false
569 }
570 for i := 0; i < int(x.Length); i++ {
571 a := x.Data[i]
572 b := y.Data[i]
573 if 'A' <= a && a <= 'Z' {
574 a += 0x20
575 }
576 if 'A' <= b && b <= 'Z' {
577 b += 0x20
578 }
579 if a != b {
580 return false
581 }
582 }
583 return true
584 }
585
586 func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
587 if !respHdr.Response {
588 return false
589 }
590 if reqID != respHdr.ID {
591 return false
592 }
593 if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
594 return false
595 }
596 return true
597 }
598
599 func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
600 if _, err := c.Write(b); err != nil {
601 return dnsmessage.Parser{}, dnsmessage.Header{}, err
602 }
603 b = make([]byte, 512)
604 for {
605 n, err := c.Read(b)
606 if err != nil {
607 return dnsmessage.Parser{}, dnsmessage.Header{}, err
608 }
609 var p dnsmessage.Parser
610 h, err := p.Start(b[:n])
611 if err != nil {
612 continue
613 }
614 q, err := p.Question()
615 if err != nil || !checkResponse(id, query, h, q) {
616 continue
617 }
618 return p, h, nil
619 }
620 }
621
622 func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
623 if _, err := c.Write(b); err != nil {
624 return dnsmessage.Parser{}, dnsmessage.Header{}, err
625 }
626 b = make([]byte, 1280)
627 if _, err := io.ReadFull(c, b[:2]); err != nil {
628 return dnsmessage.Parser{}, dnsmessage.Header{}, err
629 }
630 l := int(b[0])<<8 | int(b[1])
631 if l > len(b) {
632 b = make([]byte, l)
633 }
634 n, err := io.ReadFull(c, b[:l])
635 if err != nil {
636 return dnsmessage.Parser{}, dnsmessage.Header{}, err
637 }
638 var p dnsmessage.Parser
639 h, err := p.Start(b[:n])
640 if err != nil {
641 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
642 }
643 q, err := p.Question()
644 if err != nil {
645 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
646 }
647 if !checkResponse(id, query, h, q) {
648 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
649 }
650 return p, h, nil
651 }
652
653 func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
654 q.Class = dnsmessage.ClassINET
655 id, udpReq, tcpReq, err := newRequest(q)
656 if err != nil {
657 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
658 }
659
660 for _, useUDP := range []bool{true, false} {
661 ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
662 defer cancel()
663
664 var c net.Conn
665 var err error
666 if useUDP {
667 c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
668 } else {
669 c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
670 }
671
672 if err != nil {
673 return dnsmessage.Parser{}, dnsmessage.Header{}, err
674 }
675 if d, ok := ctx.Deadline(); ok && !d.IsZero() {
676 err := c.SetDeadline(d)
677 if err != nil {
678 return dnsmessage.Parser{}, dnsmessage.Header{}, err
679 }
680 }
681 var p dnsmessage.Parser
682 var h dnsmessage.Header
683 if useUDP {
684 p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
685 } else {
686 p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
687 }
688 c.Close()
689 if err != nil {
690 if err == context.Canceled {
691 err = errCanceled
692 } else if err == context.DeadlineExceeded {
693 err = errTimeout
694 }
695 return dnsmessage.Parser{}, dnsmessage.Header{}, err
696 }
697 if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
698 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
699 }
700 if h.Truncated {
701 continue
702 }
703 return p, h, nil
704 }
705 return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
706 }
707
708 func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
709 if h.RCode == dnsmessage.RCodeNameError {
710 return errNoSuchHost
711 }
712 _, err := p.AnswerHeader()
713 if err != nil && err != dnsmessage.ErrSectionDone {
714 return errCannotUnmarshalDNSMessage
715 }
716 if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
717 return errLameReferral
718 }
719 if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
720 if h.RCode == dnsmessage.RCodeServerFailure {
721 return errServerTemporarilyMisbehaving
722 }
723 return errServerMisbehaving
724 }
725 return nil
726 }
727
728 func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
729 for {
730 h, err := p.AnswerHeader()
731 if err == dnsmessage.ErrSectionDone {
732 return errNoSuchHost
733 }
734 if err != nil {
735 return errCannotUnmarshalDNSMessage
736 }
737 if h.Type == qtype {
738 return nil
739 }
740 if err := p.SkipAnswer(); err != nil {
741 return errCannotUnmarshalDNSMessage
742 }
743 }
744 }
745
746 func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
747 var lastErr error
748
749 n, err := dnsmessage.NewName(name)
750 if err != nil {
751 return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
752 }
753 q := dnsmessage.Question{
754 Name: n,
755 Type: qtype,
756 Class: dnsmessage.ClassINET,
757 }
758
759 for i := 0; i < 2; i++ {
760 for _, server := range tnet.dnsServers {
761 p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
762 if err != nil {
763 dnsErr := &net.DNSError{
764 Err: err.Error(),
765 Name: name,
766 Server: server.String(),
767 }
768 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
769 dnsErr.IsTimeout = true
770 }
771 if _, ok := err.(*net.OpError); ok {
772 dnsErr.IsTemporary = true
773 }
774 lastErr = dnsErr
775 continue
776 }
777
778 if err := checkHeader(&p, h); err != nil {
779 dnsErr := &net.DNSError{
780 Err: err.Error(),
781 Name: name,
782 Server: server.String(),
783 }
784 if err == errServerTemporarilyMisbehaving {
785 dnsErr.IsTemporary = true
786 }
787 if err == errNoSuchHost {
788 dnsErr.IsNotFound = true
789 return p, server.String(), dnsErr
790 }
791 lastErr = dnsErr
792 continue
793 }
794
795 err = skipToAnswer(&p, qtype)
796 if err == nil {
797 return p, server.String(), nil
798 }
799 lastErr = &net.DNSError{
800 Err: err.Error(),
801 Name: name,
802 Server: server.String(),
803 }
804 if err == errNoSuchHost {
805 lastErr.(*net.DNSError).IsNotFound = true
806 return p, server.String(), lastErr
807 }
808 }
809 }
810 return dnsmessage.Parser{}, "", lastErr
811 }
812
813 func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
814 if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
815 return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
816 }
817 zlen := len(host)
818 if strings.IndexByte(host, ':') != -1 {
819 if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
820 zlen = zidx
821 }
822 }
823 if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
824 return []string{ip.String()}, nil
825 }
826
827 if !isDomainName(host) {
828 return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
829 }
830 type result struct {
831 p dnsmessage.Parser
832 server string
833 error
834 }
835 var addrsV4, addrsV6 []netip.Addr
836 lanes := 0
837 if tnet.hasV4 {
838 lanes++
839 }
840 if tnet.hasV6 {
841 lanes++
842 }
843 lane := make(chan result, lanes)
844 var lastErr error
845 if tnet.hasV4 {
846 go func() {
847 p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
848 lane <- result{p, server, err}
849 }()
850 }
851 if tnet.hasV6 {
852 go func() {
853 p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
854 lane <- result{p, server, err}
855 }()
856 }
857 for l := 0; l < lanes; l++ {
858 result := <-lane
859 if result.error != nil {
860 if lastErr == nil {
861 lastErr = result.error
862 }
863 continue
864 }
865
866 loop:
867 for {
868 h, err := result.p.AnswerHeader()
869 if err != nil && err != dnsmessage.ErrSectionDone {
870 lastErr = &net.DNSError{
871 Err: errCannotMarshalDNSMessage.Error(),
872 Name: host,
873 Server: result.server,
874 }
875 }
876 if err != nil {
877 break
878 }
879 switch h.Type {
880 case dnsmessage.TypeA:
881 a, err := result.p.AResource()
882 if err != nil {
883 lastErr = &net.DNSError{
884 Err: errCannotMarshalDNSMessage.Error(),
885 Name: host,
886 Server: result.server,
887 }
888 break loop
889 }
890 addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
891
892 case dnsmessage.TypeAAAA:
893 aaaa, err := result.p.AAAAResource()
894 if err != nil {
895 lastErr = &net.DNSError{
896 Err: errCannotMarshalDNSMessage.Error(),
897 Name: host,
898 Server: result.server,
899 }
900 break loop
901 }
902 addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
903
904 default:
905 if err := result.p.SkipAnswer(); err != nil {
906 lastErr = &net.DNSError{
907 Err: errCannotMarshalDNSMessage.Error(),
908 Name: host,
909 Server: result.server,
910 }
911 break loop
912 }
913 continue
914 }
915 }
916 }
917 // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
918 var addrs []netip.Addr
919 if tnet.hasV6 {
920 addrs = append(addrsV6, addrsV4...)
921 } else {
922 addrs = append(addrsV4, addrsV6...)
923 }
924
925 if len(addrs) == 0 && lastErr != nil {
926 return nil, lastErr
927 }
928 saddrs := make([]string, 0, len(addrs))
929 for _, ip := range addrs {
930 saddrs = append(saddrs, ip.String())
931 }
932 return saddrs, nil
933 }
934
935 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
936 if deadline.IsZero() {
937 return deadline, nil
938 }
939 timeRemaining := deadline.Sub(now)
940 if timeRemaining <= 0 {
941 return time.Time{}, errTimeout
942 }
943 timeout := timeRemaining / time.Duration(addrsRemaining)
944 const saneMinimum = 2 * time.Second
945 if timeout < saneMinimum {
946 if timeRemaining < saneMinimum {
947 timeout = timeRemaining
948 } else {
949 timeout = saneMinimum
950 }
951 }
952 return now.Add(timeout), nil
953 }
954
// protoSplitter parses a network name such as "tcp", "udp6" or "ping4".
// Submatch 1 is the protocol and submatch 2 the optional IP version suffix.
var protoSplitter = regexp.MustCompile("^(tcp|udp|ping)(4|6)?$")
956
957 func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
958 if ctx == nil {
959 panic("nil context")
960 }
961 var acceptV4, acceptV6 bool
962 matches := protoSplitter.FindStringSubmatch(network)
963 if matches == nil {
964 return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
965 } else if len(matches[2]) == 0 {
966 acceptV4 = true
967 acceptV6 = true
968 } else {
969 acceptV4 = matches[2][0] == '4'
970 acceptV6 = !acceptV4
971 }
972 var host string
973 var port int
974 if matches[1] == "ping" {
975 host = address
976 } else {
977 var sport string
978 var err error
979 host, sport, err = net.SplitHostPort(address)
980 if err != nil {
981 return nil, &net.OpError{Op: "dial", Err: err}
982 }
983 port, err = strconv.Atoi(sport)
984 if err != nil || port < 0 || port > 65535 {
985 return nil, &net.OpError{Op: "dial", Err: errNumericPort}
986 }
987 }
988 allAddr, err := tnet.LookupContextHost(ctx, host)
989 if err != nil {
990 return nil, &net.OpError{Op: "dial", Err: err}
991 }
992 var addrs []netip.AddrPort
993 for _, addr := range allAddr {
994 ip, err := netip.ParseAddr(addr)
995 if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
996 addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
997 }
998 }
999 if len(addrs) == 0 && len(allAddr) != 0 {
1000 return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
1001 }
1002
1003 var firstErr error
1004 for i, addr := range addrs {
1005 select {
1006 case <-ctx.Done():
1007 err := ctx.Err()
1008 if err == context.Canceled {
1009 err = errCanceled
1010 } else if err == context.DeadlineExceeded {
1011 err = errTimeout
1012 }
1013 return nil, &net.OpError{Op: "dial", Err: err}
1014 default:
1015 }
1016
1017 dialCtx := ctx
1018 if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
1019 partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
1020 if err != nil {
1021 if firstErr == nil {
1022 firstErr = &net.OpError{Op: "dial", Err: err}
1023 }
1024 break
1025 }
1026 if partialDeadline.Before(deadline) {
1027 var cancel context.CancelFunc
1028 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
1029 defer cancel()
1030 }
1031 }
1032
1033 var c net.Conn
1034 switch matches[1] {
1035 case "tcp":
1036 c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
1037 case "udp":
1038 c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
1039 case "ping":
1040 c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
1041 }
1042 if err == nil {
1043 return c, nil
1044 }
1045 if firstErr == nil {
1046 firstErr = err
1047 }
1048 }
1049 if firstErr == nil {
1050 firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
1051 }
1052 return nil, firstErr
1053 }
1054
1055 func (tnet *Net) Dial(network, address string) (net.Conn, error) {
1056 return tnet.DialContext(context.Background(), network, address)
1057 }
1058