client.go raw

   1  package dns
   2  
   3  // A client implementation.
   4  
   5  import (
   6  	"context"
   7  	"crypto/tls"
   8  	"encoding/binary"
   9  	"io"
  10  	"net"
  11  	"strings"
  12  	"time"
  13  )
  14  
  // Package-level timeout defaults.
const (
	// dnsTimeout is the fallback dial, read and write timeout used by Client
	// when no explicit timeout is configured.
	dnsTimeout = 2 * time.Second
	// tcpIdleTimeout is an idle timeout for TCP connections; it is not
	// referenced by the client code in this file.
	tcpIdleTimeout = 8 * time.Second
)
  19  
  // isPacketConn reports whether c carries whole DNS messages per read/write
// (no length prefix). That is true for any net.PacketConn, except Unix
// sockets, which qualify only for the "unixgram" and "unixpacket" networks.
func isPacketConn(c net.Conn) bool {
	_, isPacket := c.(net.PacketConn)
	if !isPacket {
		return false
	}

	unixAddr, isUnix := c.LocalAddr().(*net.UnixAddr)
	if !isUnix {
		return true
	}

	switch unixAddr.Net {
	case "unixgram", "unixpacket":
		return true
	default:
		return false
	}
}
  31  
  32  // A Conn represents a connection to a DNS server.
  33  type Conn struct {
  34  	net.Conn                         // a net.Conn holding the connection
  35  	UDPSize        uint16            // minimum receive buffer for UDP messages
  36  	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
  37  	TsigProvider   TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
  38  	tsigRequestMAC string
  39  }
  40  
  41  func (co *Conn) tsigProvider() TsigProvider {
  42  	if co.TsigProvider != nil {
  43  		return co.TsigProvider
  44  	}
  45  	// tsigSecretProvider will return ErrSecret if co.TsigSecret is nil.
  46  	return tsigSecretProvider(co.TsigSecret)
  47  }
  48  
  49  // A Client defines parameters for a DNS client.
  50  type Client struct {
  51  	Net       string      // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
  52  	UDPSize   uint16      // minimum receive buffer for UDP messages
  53  	TLSConfig *tls.Config // TLS connection configuration
  54  	Dialer    *net.Dialer // a net.Dialer used to set local address, timeouts and more
  55  	// Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout,
  56  	// WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and
  57  	// Client.Dialer) or context.Context.Deadline (see ExchangeContext)
  58  	Timeout      time.Duration
  59  	DialTimeout  time.Duration     // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
  60  	ReadTimeout  time.Duration     // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
  61  	WriteTimeout time.Duration     // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
  62  	TsigSecret   map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
  63  	TsigProvider TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
  64  
  65  	// SingleInflight previously serialised multiple concurrent queries for the
  66  	// same Qname, Qtype and Qclass to ensure only one would be in flight at a
  67  	// time.
  68  	//
  69  	// Deprecated: This is a no-op. Callers should implement their own in flight
  70  	// query caching if needed. See github.com/miekg/dns/issues/1449.
  71  	SingleInflight bool
  72  }
  73  
  74  // Exchange performs a synchronous UDP query. It sends the message m to the address
  75  // contained in a and waits for a reply. Exchange does not retry a failed query, nor
  76  // will it fall back to TCP in case of truncation.
  77  // See client.Exchange for more information on setting larger buffer sizes.
  78  func Exchange(m *Msg, a string) (r *Msg, err error) {
  79  	client := Client{Net: "udp"}
  80  	r, _, err = client.Exchange(m, a)
  81  	return r, err
  82  }
  83  
  84  func (c *Client) dialTimeout() time.Duration {
  85  	if c.Timeout != 0 {
  86  		return c.Timeout
  87  	}
  88  	if c.DialTimeout != 0 {
  89  		return c.DialTimeout
  90  	}
  91  	return dnsTimeout
  92  }
  93  
  94  func (c *Client) readTimeout() time.Duration {
  95  	if c.ReadTimeout != 0 {
  96  		return c.ReadTimeout
  97  	}
  98  	return dnsTimeout
  99  }
 100  
 101  func (c *Client) writeTimeout() time.Duration {
 102  	if c.WriteTimeout != 0 {
 103  		return c.WriteTimeout
 104  	}
 105  	return dnsTimeout
 106  }
 107  
 108  // Dial connects to the address on the named network.
 109  func (c *Client) Dial(address string) (conn *Conn, err error) {
 110  	return c.DialContext(context.Background(), address)
 111  }
 112  
 113  // DialContext connects to the address on the named network, with a context.Context.
 114  func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) {
 115  	// create a new dialer with the appropriate timeout
 116  	var d net.Dialer
 117  	if c.Dialer == nil {
 118  		d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())}
 119  	} else {
 120  		d = *c.Dialer
 121  	}
 122  
 123  	network := c.Net
 124  	if network == "" {
 125  		network = "udp"
 126  	}
 127  
 128  	useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls")
 129  
 130  	conn = new(Conn)
 131  	if useTLS {
 132  		network = strings.TrimSuffix(network, "-tls")
 133  
 134  		tlsDialer := tls.Dialer{
 135  			NetDialer: &d,
 136  			Config:    c.TLSConfig,
 137  		}
 138  		conn.Conn, err = tlsDialer.DialContext(ctx, network, address)
 139  	} else {
 140  		conn.Conn, err = d.DialContext(ctx, network, address)
 141  	}
 142  	if err != nil {
 143  		return nil, err
 144  	}
 145  	conn.UDPSize = c.UDPSize
 146  	return conn, nil
 147  }
 148  
 149  // Exchange performs a synchronous query. It sends the message m to the address
 150  // contained in a and waits for a reply. Basic use pattern with a *dns.Client:
 151  //
 152  //	c := new(dns.Client)
 153  //	in, rtt, err := c.Exchange(message, "127.0.0.1:53")
 154  //
 155  // Exchange does not retry a failed query, nor will it fall back to TCP in
 156  // case of truncation.
 157  // It is up to the caller to create a message that allows for larger responses to be
 158  // returned. Specifically this means adding an EDNS0 OPT RR that will advertise a larger
 159  // buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
 160  // of 512 bytes
 161  // To specify a local address or a timeout, the caller has to set the `Client.Dialer`
 162  // attribute appropriately
 163  func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
 164  	co, err := c.Dial(address)
 165  
 166  	if err != nil {
 167  		return nil, 0, err
 168  	}
 169  	defer co.Close()
 170  	return c.ExchangeWithConn(m, co)
 171  }
 172  
 173  // ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection
 174  // that will be used instead of creating a new one.
 175  // Usage pattern with a *dns.Client:
 176  //
 177  //	c := new(dns.Client)
 178  //	// connection management logic goes here
 179  //
 180  //	conn := c.Dial(address)
 181  //	in, rtt, err := c.ExchangeWithConn(message, conn)
 182  //
 183  // This allows users of the library to implement their own connection management,
 184  // as opposed to Exchange, which will always use new connections and incur the added overhead
 185  // that entails when using "tcp" and especially "tcp-tls" clients.
 186  func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
 187  	return c.ExchangeWithConnContext(context.Background(), m, conn)
 188  }
 189  
 190  // ExchangeWithConnContext has the same behaviour as ExchangeWithConn and
 191  // additionally obeys deadlines from the passed Context.
 192  func (c *Client) ExchangeWithConnContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
 193  	opt := m.IsEdns0()
 194  	// If EDNS0 is used use that for size.
 195  	if opt != nil && opt.UDPSize() >= MinMsgSize {
 196  		co.UDPSize = opt.UDPSize()
 197  	}
 198  	// Otherwise use the client's configured UDP size.
 199  	if opt == nil && c.UDPSize >= MinMsgSize {
 200  		co.UDPSize = c.UDPSize
 201  	}
 202  
 203  	// write with the appropriate write timeout
 204  	t := time.Now()
 205  	writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout()))
 206  	readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout()))
 207  	if deadline, ok := ctx.Deadline(); ok {
 208  		if deadline.Before(writeDeadline) {
 209  			writeDeadline = deadline
 210  		}
 211  		if deadline.Before(readDeadline) {
 212  			readDeadline = deadline
 213  		}
 214  	}
 215  	co.SetWriteDeadline(writeDeadline)
 216  	co.SetReadDeadline(readDeadline)
 217  
 218  	co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
 219  
 220  	if err = co.WriteMsg(m); err != nil {
 221  		return nil, 0, err
 222  	}
 223  
 224  	if isPacketConn(co.Conn) {
 225  		for {
 226  			r, err = co.ReadMsg()
 227  			// Ignore replies with mismatched IDs because they might be
 228  			// responses to earlier queries that timed out.
 229  			if err != nil || r.Id == m.Id {
 230  				break
 231  			}
 232  		}
 233  	} else {
 234  		r, err = co.ReadMsg()
 235  		if err == nil && r.Id != m.Id {
 236  			err = ErrId
 237  		}
 238  	}
 239  	rtt = time.Since(t)
 240  	return r, rtt, err
 241  }
 242  
 243  // ReadMsg reads a message from the connection co.
 244  // If the received message contains a TSIG record the transaction signature
 245  // is verified. This method always tries to return the message, however if an
 246  // error is returned there are no guarantees that the returned message is a
 247  // valid representation of the packet read.
 248  func (co *Conn) ReadMsg() (*Msg, error) {
 249  	p, err := co.ReadMsgHeader(nil)
 250  	if err != nil {
 251  		return nil, err
 252  	}
 253  
 254  	m := new(Msg)
 255  	if err := m.Unpack(p); err != nil {
 256  		// If an error was returned, we still want to allow the user to use
 257  		// the message, but naively they can just check err if they don't want
 258  		// to use an erroneous message
 259  		return m, err
 260  	}
 261  	if t := m.IsTsig(); t != nil {
 262  		// Need to work on the original message p, as that was used to calculate the tsig.
 263  		err = TsigVerifyWithProvider(p, co.tsigProvider(), co.tsigRequestMAC, false)
 264  	}
 265  	return m, err
 266  }
 267  
 268  // ReadMsgHeader reads a DNS message, parses and populates hdr (when hdr is not nil).
 269  // Returns message as a byte slice to be parsed with Msg.Unpack later on.
 270  // Note that error handling on the message body is not possible as only the header is parsed.
 271  func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
 272  	var (
 273  		p   []byte
 274  		n   int
 275  		err error
 276  	)
 277  
 278  	if isPacketConn(co.Conn) {
 279  		if co.UDPSize > MinMsgSize {
 280  			p = make([]byte, co.UDPSize)
 281  		} else {
 282  			p = make([]byte, MinMsgSize)
 283  		}
 284  		n, err = co.Read(p)
 285  	} else {
 286  		var length uint16
 287  		if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
 288  			return nil, err
 289  		}
 290  
 291  		p = make([]byte, length)
 292  		n, err = io.ReadFull(co.Conn, p)
 293  	}
 294  
 295  	if err != nil {
 296  		return nil, err
 297  	} else if n < headerSize {
 298  		return nil, ErrShortRead
 299  	}
 300  
 301  	p = p[:n]
 302  	if hdr != nil {
 303  		dh, _, err := unpackMsgHdr(p, 0)
 304  		if err != nil {
 305  			return nil, err
 306  		}
 307  		*hdr = dh
 308  	}
 309  	return p, err
 310  }
 311  
 312  // Read implements the net.Conn read method.
 313  func (co *Conn) Read(p []byte) (n int, err error) {
 314  	if co.Conn == nil {
 315  		return 0, ErrConnEmpty
 316  	}
 317  
 318  	if isPacketConn(co.Conn) {
 319  		// UDP connection
 320  		return co.Conn.Read(p)
 321  	}
 322  
 323  	var length uint16
 324  	if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
 325  		return 0, err
 326  	}
 327  	if int(length) > len(p) {
 328  		return 0, io.ErrShortBuffer
 329  	}
 330  
 331  	return io.ReadFull(co.Conn, p[:length])
 332  }
 333  
 334  // WriteMsg sends a message through the connection co.
 335  // If the message m contains a TSIG record the transaction
 336  // signature is calculated.
 337  func (co *Conn) WriteMsg(m *Msg) (err error) {
 338  	var out []byte
 339  	if t := m.IsTsig(); t != nil {
 340  		// Set tsigRequestMAC for the next read, although only used in zone transfers.
 341  		out, co.tsigRequestMAC, err = TsigGenerateWithProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
 342  	} else {
 343  		out, err = m.Pack()
 344  	}
 345  	if err != nil {
 346  		return err
 347  	}
 348  	_, err = co.Write(out)
 349  	return err
 350  }
 351  
 352  // Write implements the net.Conn Write method.
 353  func (co *Conn) Write(p []byte) (int, error) {
 354  	if len(p) > MaxMsgSize {
 355  		return 0, &Error{err: "message too large"}
 356  	}
 357  
 358  	if isPacketConn(co.Conn) {
 359  		return co.Conn.Write(p)
 360  	}
 361  
 362  	msg := make([]byte, 2+len(p))
 363  	binary.BigEndian.PutUint16(msg, uint16(len(p)))
 364  	copy(msg[2:], p)
 365  	return co.Conn.Write(msg)
 366  }
 367  
 368  // Return the appropriate timeout for a specific request
 369  func (c *Client) getTimeoutForRequest(timeout time.Duration) time.Duration {
 370  	var requestTimeout time.Duration
 371  	if c.Timeout != 0 {
 372  		requestTimeout = c.Timeout
 373  	} else {
 374  		requestTimeout = timeout
 375  	}
 376  	// net.Dialer.Timeout has priority if smaller than the timeouts computed so
 377  	// far
 378  	if c.Dialer != nil && c.Dialer.Timeout != 0 {
 379  		if c.Dialer.Timeout < requestTimeout {
 380  			requestTimeout = c.Dialer.Timeout
 381  		}
 382  	}
 383  	return requestTimeout
 384  }
 385  
 386  // Dial connects to the address on the named network.
 387  func Dial(network, address string) (conn *Conn, err error) {
 388  	conn = new(Conn)
 389  	conn.Conn, err = net.Dial(network, address)
 390  	if err != nil {
 391  		return nil, err
 392  	}
 393  	return conn, nil
 394  }
 395  
 396  // ExchangeContext performs a synchronous UDP query, like Exchange. It
 397  // additionally obeys deadlines from the passed Context.
 398  func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
 399  	client := Client{Net: "udp"}
 400  	r, _, err = client.ExchangeContext(ctx, m, a)
 401  	// ignoring rtt to leave the original ExchangeContext API unchanged, but
 402  	// this function will go away
 403  	return r, err
 404  }
 405  
 406  // ExchangeConn performs a synchronous query. It sends the message m via the connection
 407  // c and waits for a reply. The connection c is not closed by ExchangeConn.
 408  // Deprecated: This function is going away, but can easily be mimicked:
 409  //
 410  //	co := &dns.Conn{Conn: c} // c is your net.Conn
 411  //	co.WriteMsg(m)
 412  //	in, _  := co.ReadMsg()
 413  //	co.Close()
 414  func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
 415  	println("dns: ExchangeConn: this function is deprecated")
 416  	co := new(Conn)
 417  	co.Conn = c
 418  	if err = co.WriteMsg(m); err != nil {
 419  		return nil, err
 420  	}
 421  	r, err = co.ReadMsg()
 422  	if err == nil && r.Id != m.Id {
 423  		err = ErrId
 424  	}
 425  	return r, err
 426  }
 427  
 428  // DialTimeout acts like Dial but takes a timeout.
 429  func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
 430  	client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}}
 431  	return client.Dial(address)
 432  }
 433  
 434  // DialWithTLS connects to the address on the named network with TLS.
 435  func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) {
 436  	if !strings.HasSuffix(network, "-tls") {
 437  		network += "-tls"
 438  	}
 439  	client := Client{Net: network, TLSConfig: tlsConfig}
 440  	return client.Dial(address)
 441  }
 442  
 443  // DialTimeoutWithTLS acts like DialWithTLS but takes a timeout.
 444  func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) {
 445  	if !strings.HasSuffix(network, "-tls") {
 446  		network += "-tls"
 447  	}
 448  	client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig}
 449  	return client.Dial(address)
 450  }
 451  
 452  // ExchangeContext acts like Exchange, but honors the deadline on the provided
 453  // context, if present. If there is both a context deadline and a configured
 454  // timeout on the client, the earliest of the two takes effect.
 455  func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
 456  	conn, err := c.DialContext(ctx, a)
 457  	if err != nil {
 458  		return nil, 0, err
 459  	}
 460  	defer conn.Close()
 461  
 462  	return c.ExchangeWithConnContext(ctx, m, conn)
 463  }
 464