msg.go raw

   // DNS packet assembly, see RFC 1035. Converting from - Unpack() -
// and to - Pack() - wire format.
// Most packers and unpackers take a (msg []byte, off int)
// and return (off1 int, err error). When they fail they generally
// also return off1 == len(msg), so that the next unpacker will
// also fail. This lets us defer error checks until the end of a
// packing sequence.
   8  
   9  package dns
  10  
  11  //go:generate go run msg_generate.go
  12  
  13  import (
  14  	"crypto/rand"
  15  	"encoding/binary"
  16  	"fmt"
  17  	"math/big"
  18  	"strconv"
  19  	"strings"
  20  )
  21  
  const (
	// maxCompressionOffset is the first offset that a compression pointer can
	// no longer express: pointers carry a 14-bit offset, so only offsets in
	// [0, 1<<14) may be recorded as compression targets.
	maxCompressionOffset = 1 << 14

	// maxDomainNameWireOctets is the longest a domain name may be in wire
	// format, see RFC 1035 section 2.3.4.
	maxDomainNameWireOctets = 255

	// maxCompressionPointers is the largest number of compression pointers a
	// semantically valid name can contain. Each label is at least one octet
	// and is separated from the next by a period, which bounds the count at
	// (maxDomainNameWireOctets+1)/2. The root label is never reached through a
	// pointer to a pointer, hence the -2 that excludes the smallest valid root
	// label.
	//
	// A message can exceed this limit without looping by chaining pointers to
	// earlier pointers. Well-written encoders never do that, so such names are
	// left to trip this check.
	maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2

	// maxDomainNamePresentationLength is the longest a domain name can be in
	// presentation format. The wire limit is 255 octets with labels of at most
	// 63 octets. The wire form needs one extra octet over the presentation
	// form, so one octet fewer is available. Labels are joined by a single
	// period, and each label octet expands to at most 4 bytes (\DDD). With
	// three maximum-length labels, the last label can hold only 61 octets
	// without exceeding the wire limit.
	maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1
)
  48  
  49  // Errors defined in this package.
  50  var (
  51  	ErrAlg           error = &Error{err: "bad algorithm"}                  // ErrAlg indicates an error with the (DNSSEC) algorithm.
  52  	ErrAuth          error = &Error{err: "bad authentication"}             // ErrAuth indicates an error in the TSIG authentication.
  53  	ErrBuf           error = &Error{err: "buffer size too small"}          // ErrBuf indicates that the buffer used is too small for the message.
  54  	ErrConnEmpty     error = &Error{err: "conn has no connection"}         // ErrConnEmpty indicates a connection is being used before it is initialized.
  55  	ErrExtendedRcode error = &Error{err: "bad extended rcode"}             // ErrExtendedRcode ...
  56  	ErrFqdn          error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot.
  57  	ErrId            error = &Error{err: "id mismatch"}                    // ErrId indicates there is a mismatch with the message's ID.
  58  	ErrKeyAlg        error = &Error{err: "bad key algorithm"}              // ErrKeyAlg indicates that the algorithm in the key is not valid.
  59  	ErrKey           error = &Error{err: "bad key"}
  60  	ErrKeySize       error = &Error{err: "bad key size"}
  61  	ErrLongDomain    error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)}
  62  	ErrNoSig         error = &Error{err: "no signature found"}
  63  	ErrPrivKey       error = &Error{err: "bad private key"}
  64  	ErrRcode         error = &Error{err: "bad rcode"}
  65  	ErrRdata         error = &Error{err: "bad rdata"}
  66  	ErrRRset         error = &Error{err: "bad rrset"}
  67  	ErrSecret        error = &Error{err: "no secrets defined"}
  68  	ErrShortRead     error = &Error{err: "short read"}
  69  	ErrSig           error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
  70  	ErrSoa           error = &Error{err: "no SOA"}        // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
  71  	ErrTime          error = &Error{err: "bad time"}      // ErrTime indicates a timing error in TSIG authentication.
  72  )
  73  
  // Id returns a 16-bit number to be used as a message id. By default it is
// drawn from crypto/rand, a cryptographically secure random number
// generator. Id is a variable so that it can be replaced by a custom
// function. For instance, to make it return a static value for testing:
//
//	dns.Id = func() uint16 { return 3 }
var Id = id

// id reads a big-endian uint16 from crypto/rand. It panics if the random
// source cannot be read.
func id() uint16 {
	var n uint16
	if err := binary.Read(rand.Reader, binary.BigEndian, &n); err != nil {
		panic("dns: reading random id failed: " + err.Error())
	}
	return n
}
  92  
  // MsgHdr is a manually-unpacked version of (id, bits): the message id plus
// the header's flags word broken out into individual fields.
type MsgHdr struct {
	Id                 uint16 // Message id.
	Response           bool   // QR flag.
	Opcode             int    // Operation code; shifted left by 11 into the flags word when packed.
	Authoritative      bool   // AA flag.
	Truncated          bool   // TC flag.
	RecursionDesired   bool   // RD flag.
	RecursionAvailable bool   // RA flag.
	Zero               bool   // Z flag.
	AuthenticatedData  bool   // AD flag.
	CheckingDisabled   bool   // CD flag.
	Rcode              int    // Response code; values above 0xF (up to 0xFFF) need an OPT record when packing.
}
 107  
 108  // Msg contains the layout of a DNS message.
 109  type Msg struct {
 110  	MsgHdr
 111  	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format.
 112  	Question []Question // Holds the RR(s) of the question section.
 113  	Answer   []RR       // Holds the RR(s) of the answer section.
 114  	Ns       []RR       // Holds the RR(s) of the authority section.
 115  	Extra    []RR       // Holds the RR(s) of the additional section.
 116  }
 117  
 118  // ClassToString is a maps Classes to strings for each CLASS wire type.
 119  var ClassToString = map[uint16]string{
 120  	ClassINET:   "IN",
 121  	ClassCSNET:  "CS",
 122  	ClassCHAOS:  "CH",
 123  	ClassHESIOD: "HS",
 124  	ClassNONE:   "NONE",
 125  	ClassANY:    "ANY",
 126  }
 127  
 128  // OpcodeToString maps Opcodes to strings.
 129  var OpcodeToString = map[int]string{
 130  	OpcodeQuery:  "QUERY",
 131  	OpcodeIQuery: "IQUERY",
 132  	OpcodeStatus: "STATUS",
 133  	OpcodeNotify: "NOTIFY",
 134  	OpcodeUpdate: "UPDATE",
 135  }
 136  
 137  // RcodeToString maps Rcodes to strings.
 138  var RcodeToString = map[int]string{
 139  	RcodeSuccess:                    "NOERROR",
 140  	RcodeFormatError:                "FORMERR",
 141  	RcodeServerFailure:              "SERVFAIL",
 142  	RcodeNameError:                  "NXDOMAIN",
 143  	RcodeNotImplemented:             "NOTIMP",
 144  	RcodeRefused:                    "REFUSED",
 145  	RcodeYXDomain:                   "YXDOMAIN", // See RFC 2136
 146  	RcodeYXRrset:                    "YXRRSET",
 147  	RcodeNXRrset:                    "NXRRSET",
 148  	RcodeNotAuth:                    "NOTAUTH",
 149  	RcodeNotZone:                    "NOTZONE",
 150  	RcodeStatefulTypeNotImplemented: "DSOTYPENI",
 151  	RcodeBadSig:                     "BADSIG", // Also known as RcodeBadVers, see RFC 6891
 152  	//	RcodeBadVers:        "BADVERS",
 153  	RcodeBadKey:    "BADKEY",
 154  	RcodeBadTime:   "BADTIME",
 155  	RcodeBadMode:   "BADMODE",
 156  	RcodeBadName:   "BADNAME",
 157  	RcodeBadAlg:    "BADALG",
 158  	RcodeBadTrunc:  "BADTRUNC",
 159  	RcodeBadCookie: "BADCOOKIE",
 160  }
 161  
 // compressionMap lets internal packDomainName calls use a more compact
// compression map without changing the signature or behavior of the public
// API, which takes a map[string]int.
//
// A map[string]uint16 uses about 25% less memory per entry than a
// map[string]int. When both maps are set, ext takes precedence.
type compressionMap struct {
	ext map[string]int    // supplied by external callers
	int map[string]uint16 // used by internal callers
}

// valid reports whether either backing map is present.
func (m compressionMap) valid() bool {
	return !(m.ext == nil && m.int == nil)
}

// insert records that the name s was written at offset pos. The internal
// map stores pos as a uint16; callers only insert offsets below
// maxCompressionOffset.
func (m compressionMap) insert(s string, pos int) {
	if m.ext == nil {
		m.int[s] = uint16(pos)
		return
	}
	m.ext[s] = pos
}

// find returns the offset recorded for s and whether one was found.
func (m compressionMap) find(s string) (int, bool) {
	if m.ext == nil {
		p, ok := m.int[s]
		return int(p), ok
	}
	p, ok := m.ext[s]
	return p, ok
}
 194  
 195  // Domain names are a sequence of counted strings
 196  // split at the dots. They end with a zero-length string.
 197  
 198  // PackDomainName packs a domain name s into msg[off:].
 199  // If compression is wanted compress must be true and the compression
 200  // map needs to hold a mapping between domain names and offsets
 201  // pointing into msg.
 202  func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
 203  	return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
 204  }
 205  
 206  func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
 207  	// XXX: A logical copy of this function exists in IsDomainName and
 208  	// should be kept in sync with this function.
 209  
 210  	ls := len(s)
 211  	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
 212  		return off, nil
 213  	}
 214  
 215  	// If not fully qualified, error out.
 216  	if !IsFqdn(s) {
 217  		return len(msg), ErrFqdn
 218  	}
 219  
 220  	// Each dot ends a segment of the name.
 221  	// We trade each dot byte for a length byte.
 222  	// Except for escaped dots (\.), which are normal dots.
 223  	// There is also a trailing zero.
 224  
 225  	// Compression
 226  	pointer := -1
 227  
 228  	// Emit sequence of counted strings, chopping at dots.
 229  	var (
 230  		begin     int
 231  		compBegin int
 232  		compOff   int
 233  		bs        []byte
 234  		wasDot    bool
 235  	)
 236  loop:
 237  	for i := 0; i < ls; i++ {
 238  		var c byte
 239  		if bs == nil {
 240  			c = s[i]
 241  		} else {
 242  			c = bs[i]
 243  		}
 244  
 245  		switch c {
 246  		case '\\':
 247  			if off+1 > len(msg) {
 248  				return len(msg), ErrBuf
 249  			}
 250  
 251  			if bs == nil {
 252  				bs = []byte(s)
 253  			}
 254  
 255  			// check for \DDD
 256  			if isDDD(bs[i+1:]) {
 257  				bs[i] = dddToByte(bs[i+1:])
 258  				copy(bs[i+1:ls-3], bs[i+4:])
 259  				ls -= 3
 260  				compOff += 3
 261  			} else {
 262  				copy(bs[i:ls-1], bs[i+1:])
 263  				ls--
 264  				compOff++
 265  			}
 266  
 267  			wasDot = false
 268  		case '.':
 269  			if i == 0 && len(s) > 1 {
 270  				// leading dots are not legal except for the root zone
 271  				return len(msg), ErrRdata
 272  			}
 273  
 274  			if wasDot {
 275  				// two dots back to back is not legal
 276  				return len(msg), ErrRdata
 277  			}
 278  			wasDot = true
 279  
 280  			labelLen := i - begin
 281  			if labelLen >= 1<<6 { // top two bits of length must be clear
 282  				return len(msg), ErrRdata
 283  			}
 284  
 285  			// off can already (we're in a loop) be bigger than len(msg)
 286  			// this happens when a name isn't fully qualified
 287  			if off+1+labelLen > len(msg) {
 288  				return len(msg), ErrBuf
 289  			}
 290  
 291  			// Don't try to compress '.'
 292  			// We should only compress when compress is true, but we should also still pick
 293  			// up names that can be used for *future* compression(s).
 294  			if compression.valid() && !isRootLabel(s, bs, begin, ls) {
 295  				if p, ok := compression.find(s[compBegin:]); ok {
 296  					// The first hit is the longest matching dname
 297  					// keep the pointer offset we get back and store
 298  					// the offset of the current name, because that's
 299  					// where we need to insert the pointer later
 300  
 301  					// If compress is true, we're allowed to compress this dname
 302  					if compress {
 303  						pointer = p // Where to point to
 304  						break loop
 305  					}
 306  				} else if off < maxCompressionOffset {
 307  					// Only offsets smaller than maxCompressionOffset can be used.
 308  					compression.insert(s[compBegin:], off)
 309  				}
 310  			}
 311  
 312  			// The following is covered by the length check above.
 313  			msg[off] = byte(labelLen)
 314  
 315  			if bs == nil {
 316  				copy(msg[off+1:], s[begin:i])
 317  			} else {
 318  				copy(msg[off+1:], bs[begin:i])
 319  			}
 320  			off += 1 + labelLen
 321  
 322  			begin = i + 1
 323  			compBegin = begin + compOff
 324  		default:
 325  			wasDot = false
 326  		}
 327  	}
 328  
 329  	// Root label is special
 330  	if isRootLabel(s, bs, 0, ls) {
 331  		return off, nil
 332  	}
 333  
 334  	// If we did compression and we find something add the pointer here
 335  	if pointer != -1 {
 336  		// We have two bytes (14 bits) to put the pointer in
 337  		binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
 338  		return off + 2, nil
 339  	}
 340  
 341  	if off < len(msg) {
 342  		msg[off] = 0
 343  	}
 344  
 345  	return off + 1, nil
 346  }
 347  
 // isRootLabel reports whether the range [off, end) is exactly the root
// label ".". The range is read from bs when bs is non-nil, otherwise from s.
func isRootLabel(s string, bs []byte, off, end int) bool {
	if bs != nil {
		return end-off == 1 && bs[off] == '.'
	}
	return s[off:end] == "."
}
 359  
 360  // Unpack a domain name.
 361  // In addition to the simple sequences of counted strings above,
 362  // domain names are allowed to refer to strings elsewhere in the
 363  // packet, to avoid repeating common suffixes when returning
 364  // many entries in a single domain.  The pointers are marked
 365  // by a length byte with the top two bits set.  Ignoring those
 366  // two bits, that byte and the next give a 14 bit offset from msg[0]
 367  // where we should pick up the trail.
 368  // Note that if we jump elsewhere in the packet,
 369  // we return off1 == the offset after the first pointer we found,
 370  // which is where the next record will start.
 371  // In theory, the pointers are only allowed to jump backward.
 372  // We let them jump anywhere and stop jumping after a while.
 373  
 374  // UnpackDomainName unpacks a domain name into a string. It returns
 375  // the name, the new offset into msg and any error that occurred.
 376  //
 377  // When an error is encountered, the unpacked name will be discarded
 378  // and len(msg) will be returned as the offset.
 379  func UnpackDomainName(msg []byte, off int) (string, int, error) {
 380  	s := make([]byte, 0, maxDomainNamePresentationLength)
 381  	off1 := 0
 382  	lenmsg := len(msg)
 383  	budget := maxDomainNameWireOctets
 384  	ptr := 0 // number of pointers followed
 385  Loop:
 386  	for {
 387  		if off >= lenmsg {
 388  			return "", lenmsg, ErrBuf
 389  		}
 390  		c := int(msg[off])
 391  		off++
 392  		switch c & 0xC0 {
 393  		case 0x00:
 394  			if c == 0x00 {
 395  				// end of name
 396  				break Loop
 397  			}
 398  			// literal string
 399  			if off+c > lenmsg {
 400  				return "", lenmsg, ErrBuf
 401  			}
 402  			budget -= c + 1 // +1 for the label separator
 403  			if budget <= 0 {
 404  				return "", lenmsg, ErrLongDomain
 405  			}
 406  			for _, b := range msg[off : off+c] {
 407  				if isDomainNameLabelSpecial(b) {
 408  					s = append(s, '\\', b)
 409  				} else if b < ' ' || b > '~' {
 410  					s = append(s, escapeByte(b)...)
 411  				} else {
 412  					s = append(s, b)
 413  				}
 414  			}
 415  			s = append(s, '.')
 416  			off += c
 417  		case 0xC0:
 418  			// pointer to somewhere else in msg.
 419  			// remember location after first ptr,
 420  			// since that's how many bytes we consumed.
 421  			// also, don't follow too many pointers --
 422  			// maybe there's a loop.
 423  			if off >= lenmsg {
 424  				return "", lenmsg, ErrBuf
 425  			}
 426  			c1 := msg[off]
 427  			off++
 428  			if ptr == 0 {
 429  				off1 = off
 430  			}
 431  			if ptr++; ptr > maxCompressionPointers {
 432  				return "", lenmsg, &Error{err: "too many compression pointers"}
 433  			}
 434  			// pointer should guarantee that it advances and points forwards at least
 435  			// but the condition on previous three lines guarantees that it's
 436  			// at least loop-free
 437  			off = (c^0xC0)<<8 | int(c1)
 438  		default:
 439  			// 0x80 and 0x40 are reserved
 440  			return "", lenmsg, ErrRdata
 441  		}
 442  	}
 443  	if ptr == 0 {
 444  		off1 = off
 445  	}
 446  	if len(s) == 0 {
 447  		return ".", off1, nil
 448  	}
 449  	return string(s), off1, nil
 450  }
 451  
 452  func packTxt(txt []string, msg []byte, offset int) (int, error) {
 453  	if len(txt) == 0 {
 454  		if offset >= len(msg) {
 455  			return offset, ErrBuf
 456  		}
 457  		msg[offset] = 0
 458  		return offset, nil
 459  	}
 460  	var err error
 461  	for _, s := range txt {
 462  		offset, err = packTxtString(s, msg, offset)
 463  		if err != nil {
 464  			return offset, err
 465  		}
 466  	}
 467  	return offset, nil
 468  }
 469  
 470  func packTxtString(s string, msg []byte, offset int) (int, error) {
 471  	lenByteOffset := offset
 472  	if offset >= len(msg) || len(s) > 256*4+1 /* If all \DDD */ {
 473  		return offset, ErrBuf
 474  	}
 475  	offset++
 476  	for i := 0; i < len(s); i++ {
 477  		if len(msg) <= offset {
 478  			return offset, ErrBuf
 479  		}
 480  		if s[i] == '\\' {
 481  			i++
 482  			if i == len(s) {
 483  				break
 484  			}
 485  			// check for \DDD
 486  			if isDDD(s[i:]) {
 487  				msg[offset] = dddToByte(s[i:])
 488  				i += 2
 489  			} else {
 490  				msg[offset] = s[i]
 491  			}
 492  		} else {
 493  			msg[offset] = s[i]
 494  		}
 495  		offset++
 496  	}
 497  	l := offset - lenByteOffset - 1
 498  	if l > 255 {
 499  		return offset, &Error{err: "string exceeded 255 bytes in txt"}
 500  	}
 501  	msg[lenByteOffset] = byte(l)
 502  	return offset, nil
 503  }
 504  
 505  func packOctetString(s string, msg []byte, offset int) (int, error) {
 506  	if offset >= len(msg) || len(s) > 256*4+1 {
 507  		return offset, ErrBuf
 508  	}
 509  	for i := 0; i < len(s); i++ {
 510  		if len(msg) <= offset {
 511  			return offset, ErrBuf
 512  		}
 513  		if s[i] == '\\' {
 514  			i++
 515  			if i == len(s) {
 516  				break
 517  			}
 518  			// check for \DDD
 519  			if isDDD(s[i:]) {
 520  				msg[offset] = dddToByte(s[i:])
 521  				i += 2
 522  			} else {
 523  				msg[offset] = s[i]
 524  			}
 525  		} else {
 526  			msg[offset] = s[i]
 527  		}
 528  		offset++
 529  	}
 530  	return offset, nil
 531  }
 532  
 533  func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
 534  	off = off0
 535  	var s string
 536  	for off < len(msg) && err == nil {
 537  		s, off, err = unpackString(msg, off)
 538  		if err == nil {
 539  			ss = append(ss, s)
 540  		}
 541  	}
 542  	return
 543  }
 544  
 // Helpers for dealing with escaped bytes.

// isDigit reports whether b is an ASCII decimal digit.
func isDigit(b byte) bool { return '0' <= b && b <= '9' }

// isDDD reports whether s starts with three decimal digits, as in a \DDD
// escape (the backslash is not part of s).
func isDDD[T ~[]byte | ~string](s T) bool {
	if len(s) < 3 {
		return false
	}
	for i := 0; i < 3; i++ {
		if !isDigit(s[i]) {
			return false
		}
	}
	return true
}

// dddToByte converts the three leading decimal digits of s to a byte. The
// arithmetic is done in byte, so values above 255 wrap modulo 256.
func dddToByte[T ~[]byte | ~string](s T) byte {
	_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
	hundreds, tens, ones := s[0]-'0', s[1]-'0', s[2]-'0'
	return hundreds*100 + tens*10 + ones
}
 556  
 // intToBytes is a helper for packing and unpacking. It returns the
// big-endian magnitude of i (as produced by big.Int.Bytes), left-padded with
// zero bytes to length. A value that already needs length bytes or more is
// returned as is, neither padded nor truncated.
func intToBytes(i *big.Int, length int) []byte {
	raw := i.Bytes()
	pad := length - len(raw)
	if pad <= 0 {
		return raw
	}
	out := make([]byte, length)
	copy(out[pad:], raw)
	return out
}
 567  
 568  // PackRR packs a resource record rr into msg[off:].
 569  // See PackDomainName for documentation about the compression.
 570  func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
 571  	headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
 572  	if err == nil {
 573  		// packRR no longer sets the Rdlength field on the rr, but
 574  		// callers might be expecting it so we set it here.
 575  		rr.Header().Rdlength = uint16(off1 - headerEnd)
 576  	}
 577  	return off1, err
 578  }
 579  
 580  func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
 581  	if rr == nil {
 582  		return len(msg), len(msg), &Error{err: "nil rr"}
 583  	}
 584  
 585  	headerEnd, err = rr.Header().packHeader(msg, off, compression, compress)
 586  	if err != nil {
 587  		return headerEnd, len(msg), err
 588  	}
 589  
 590  	off1, err = rr.pack(msg, headerEnd, compression, compress)
 591  	if err != nil {
 592  		return headerEnd, len(msg), err
 593  	}
 594  
 595  	rdlength := off1 - headerEnd
 596  	if int(uint16(rdlength)) != rdlength { // overflow
 597  		return headerEnd, len(msg), ErrRdata
 598  	}
 599  
 600  	// The RDLENGTH field is the last field in the header and we set it here.
 601  	binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
 602  	return headerEnd, off1, nil
 603  }
 604  
 605  // UnpackRR unpacks msg[off:] into an RR.
 606  func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
 607  	h, off, msg, err := unpackHeader(msg, off)
 608  	if err != nil {
 609  		return nil, len(msg), err
 610  	}
 611  
 612  	return UnpackRRWithHeader(h, msg, off)
 613  }
 614  
 615  // UnpackRRWithHeader unpacks the record type specific payload given an existing
 616  // RR_Header.
 617  func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
 618  	if newFn, ok := TypeToRR[h.Rrtype]; ok {
 619  		rr = newFn()
 620  		*rr.Header() = h
 621  	} else {
 622  		rr = &RFC3597{Hdr: h}
 623  	}
 624  
 625  	if off < 0 || off > len(msg) {
 626  		return &h, off, &Error{err: "bad off"}
 627  	}
 628  
 629  	end := off + int(h.Rdlength)
 630  	if end < off || end > len(msg) {
 631  		return &h, end, &Error{err: "bad rdlength"}
 632  	}
 633  
 634  	if noRdata(h) {
 635  		return rr, off, nil
 636  	}
 637  
 638  	off, err = rr.unpack(msg, off)
 639  	if err != nil {
 640  		return nil, end, err
 641  	}
 642  	if off != end {
 643  		return &h, end, &Error{err: "bad rdlength"}
 644  	}
 645  
 646  	return rr, off, nil
 647  }
 648  
 649  // unpackRRslice unpacks msg[off:] into an []RR.
 650  // If we cannot unpack the whole array, then it will return nil
 651  func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
 652  	var r RR
 653  	// Don't pre-allocate, l may be under attacker control
 654  	var dst []RR
 655  	for i := 0; i < l; i++ {
 656  		off1 := off
 657  		r, off, err = UnpackRR(msg, off)
 658  		if err != nil {
 659  			off = len(msg)
 660  			break
 661  		}
 662  		// If offset does not increase anymore, l is a lie
 663  		if off1 == off {
 664  			break
 665  		}
 666  		dst = append(dst, r)
 667  	}
 668  	if err != nil && off == len(msg) {
 669  		dst = nil
 670  	}
 671  	return dst, off, err
 672  }
 673  
 674  // Convert a MsgHdr to a string, with dig-like headers:
 675  //
 676  // ;; opcode: QUERY, status: NOERROR, id: 48404
 677  //
 678  // ;; flags: qr aa rd ra;
 679  func (h *MsgHdr) String() string {
 680  	if h == nil {
 681  		return "<nil> MsgHdr"
 682  	}
 683  
 684  	s := ";; opcode: " + OpcodeToString[h.Opcode]
 685  	s += ", status: " + RcodeToString[h.Rcode]
 686  	s += ", id: " + strconv.Itoa(int(h.Id)) + "\n"
 687  
 688  	s += ";; flags:"
 689  	if h.Response {
 690  		s += " qr"
 691  	}
 692  	if h.Authoritative {
 693  		s += " aa"
 694  	}
 695  	if h.Truncated {
 696  		s += " tc"
 697  	}
 698  	if h.RecursionDesired {
 699  		s += " rd"
 700  	}
 701  	if h.RecursionAvailable {
 702  		s += " ra"
 703  	}
 704  	if h.Zero { // Hmm
 705  		s += " z"
 706  	}
 707  	if h.AuthenticatedData {
 708  		s += " ad"
 709  	}
 710  	if h.CheckingDisabled {
 711  		s += " cd"
 712  	}
 713  
 714  	s += ";"
 715  	return s
 716  }
 717  
 718  // Pack packs a Msg: it is converted to wire format.
 719  // If the dns.Compress is true the message will be in compressed wire format.
 720  func (dns *Msg) Pack() (msg []byte, err error) {
 721  	return dns.PackBuffer(nil)
 722  }
 723  
 724  // PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
 725  func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
 726  	// If this message can't be compressed, avoid filling the
 727  	// compression map and creating garbage.
 728  	if dns.Compress && dns.isCompressible() {
 729  		compression := make(map[string]uint16) // Compression pointer mappings.
 730  		return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
 731  	}
 732  
 733  	return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
 734  }
 735  
 736  // packBufferWithCompressionMap packs a Msg, using the given buffer buf.
 737  func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) {
 738  	if dns.Rcode < 0 || dns.Rcode > 0xFFF {
 739  		return nil, ErrRcode
 740  	}
 741  
 742  	// Set extended rcode unconditionally if we have an opt, this will allow
 743  	// resetting the extended rcode bits if they need to.
 744  	if opt := dns.IsEdns0(); opt != nil {
 745  		opt.SetExtendedRcode(uint16(dns.Rcode))
 746  	} else if dns.Rcode > 0xF {
 747  		// If Rcode is an extended one and opt is nil, error out.
 748  		return nil, ErrExtendedRcode
 749  	}
 750  
 751  	// Convert convenient Msg into wire-like Header.
 752  	var dh Header
 753  	dh.Id = dns.Id
 754  	dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
 755  	if dns.Response {
 756  		dh.Bits |= _QR
 757  	}
 758  	if dns.Authoritative {
 759  		dh.Bits |= _AA
 760  	}
 761  	if dns.Truncated {
 762  		dh.Bits |= _TC
 763  	}
 764  	if dns.RecursionDesired {
 765  		dh.Bits |= _RD
 766  	}
 767  	if dns.RecursionAvailable {
 768  		dh.Bits |= _RA
 769  	}
 770  	if dns.Zero {
 771  		dh.Bits |= _Z
 772  	}
 773  	if dns.AuthenticatedData {
 774  		dh.Bits |= _AD
 775  	}
 776  	if dns.CheckingDisabled {
 777  		dh.Bits |= _CD
 778  	}
 779  
 780  	dh.Qdcount = uint16(len(dns.Question))
 781  	dh.Ancount = uint16(len(dns.Answer))
 782  	dh.Nscount = uint16(len(dns.Ns))
 783  	dh.Arcount = uint16(len(dns.Extra))
 784  
 785  	// We need the uncompressed length here, because we first pack it and then compress it.
 786  	msg = buf
 787  	uncompressedLen := msgLenWithCompressionMap(dns, nil)
 788  	if packLen := uncompressedLen + 1; len(msg) < packLen {
 789  		msg = make([]byte, packLen)
 790  	}
 791  
 792  	// Pack it in: header and then the pieces.
 793  	off := 0
 794  	off, err = dh.pack(msg, off, compression, compress)
 795  	if err != nil {
 796  		return nil, err
 797  	}
 798  	for _, r := range dns.Question {
 799  		off, err = r.pack(msg, off, compression, compress)
 800  		if err != nil {
 801  			return nil, err
 802  		}
 803  	}
 804  	for _, r := range dns.Answer {
 805  		_, off, err = packRR(r, msg, off, compression, compress)
 806  		if err != nil {
 807  			return nil, err
 808  		}
 809  	}
 810  	for _, r := range dns.Ns {
 811  		_, off, err = packRR(r, msg, off, compression, compress)
 812  		if err != nil {
 813  			return nil, err
 814  		}
 815  	}
 816  	for _, r := range dns.Extra {
 817  		_, off, err = packRR(r, msg, off, compression, compress)
 818  		if err != nil {
 819  			return nil, err
 820  		}
 821  	}
 822  	return msg[:off], nil
 823  }
 824  
 825  func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
 826  	// If we are at the end of the message we should return *just* the
 827  	// header. This can still be useful to the caller. 9.9.9.9 sends these
 828  	// when responding with REFUSED for instance.
 829  	if off == len(msg) {
 830  		// reset sections before returning
 831  		dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil
 832  		return nil
 833  	}
 834  
 835  	// Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are
 836  	// attacker controlled. This means we can't use them to pre-allocate
 837  	// slices.
 838  	dns.Question = nil
 839  	for i := 0; i < int(dh.Qdcount); i++ {
 840  		off1 := off
 841  		var q Question
 842  		q, off, err = unpackQuestion(msg, off)
 843  		if err != nil {
 844  			return err
 845  		}
 846  		if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
 847  			dh.Qdcount = uint16(i)
 848  			break
 849  		}
 850  		dns.Question = append(dns.Question, q)
 851  	}
 852  
 853  	dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
 854  	// The header counts might have been wrong so we need to update it
 855  	dh.Ancount = uint16(len(dns.Answer))
 856  	if err == nil {
 857  		dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
 858  	}
 859  	// The header counts might have been wrong so we need to update it
 860  	dh.Nscount = uint16(len(dns.Ns))
 861  	if err == nil {
 862  		dns.Extra, _, err = unpackRRslice(int(dh.Arcount), msg, off)
 863  	}
 864  	// The header counts might have been wrong so we need to update it
 865  	dh.Arcount = uint16(len(dns.Extra))
 866  
 867  	// Set extended Rcode
 868  	if opt := dns.IsEdns0(); opt != nil {
 869  		dns.Rcode |= opt.ExtendedRcode()
 870  	}
 871  
 872  	// TODO(miek) make this an error?
 873  	// use PackOpt to let people tell how detailed the error reporting should be?
 874  	// if off != len(msg) {
 875  	//	// println("dns: extra bytes in dns packet", off, "<", len(msg))
 876  	// }
 877  	return err
 878  }
 879  
 880  // Unpack unpacks a binary message to a Msg structure.
 881  func (dns *Msg) Unpack(msg []byte) (err error) {
 882  	dh, off, err := unpackMsgHdr(msg, 0)
 883  	if err != nil {
 884  		return err
 885  	}
 886  
 887  	dns.setHdr(dh)
 888  	return dns.unpack(dh, msg, off)
 889  }
 890  
 891  // Convert a complete message to a string with dig-like output.
 892  func (dns *Msg) String() string {
 893  	if dns == nil {
 894  		return "<nil> MsgHdr"
 895  	}
 896  	s := dns.MsgHdr.String() + " "
 897  	if dns.MsgHdr.Opcode == OpcodeUpdate {
 898  		s += "ZONE: " + strconv.Itoa(len(dns.Question)) + ", "
 899  		s += "PREREQ: " + strconv.Itoa(len(dns.Answer)) + ", "
 900  		s += "UPDATE: " + strconv.Itoa(len(dns.Ns)) + ", "
 901  		s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
 902  	} else {
 903  		s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", "
 904  		s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
 905  		s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
 906  		s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
 907  	}
 908  	opt := dns.IsEdns0()
 909  	if opt != nil {
 910  		// OPT PSEUDOSECTION
 911  		s += opt.String() + "\n"
 912  	}
 913  	if len(dns.Question) > 0 {
 914  		if dns.MsgHdr.Opcode == OpcodeUpdate {
 915  			s += "\n;; ZONE SECTION:\n"
 916  		} else {
 917  			s += "\n;; QUESTION SECTION:\n"
 918  		}
 919  		for _, r := range dns.Question {
 920  			s += r.String() + "\n"
 921  		}
 922  	}
 923  	if len(dns.Answer) > 0 {
 924  		if dns.MsgHdr.Opcode == OpcodeUpdate {
 925  			s += "\n;; PREREQUISITE SECTION:\n"
 926  		} else {
 927  			s += "\n;; ANSWER SECTION:\n"
 928  		}
 929  		for _, r := range dns.Answer {
 930  			if r != nil {
 931  				s += r.String() + "\n"
 932  			}
 933  		}
 934  	}
 935  	if len(dns.Ns) > 0 {
 936  		if dns.MsgHdr.Opcode == OpcodeUpdate {
 937  			s += "\n;; UPDATE SECTION:\n"
 938  		} else {
 939  			s += "\n;; AUTHORITY SECTION:\n"
 940  		}
 941  		for _, r := range dns.Ns {
 942  			if r != nil {
 943  				s += r.String() + "\n"
 944  			}
 945  		}
 946  	}
 947  	if len(dns.Extra) > 0 && (opt == nil || len(dns.Extra) > 1) {
 948  		s += "\n;; ADDITIONAL SECTION:\n"
 949  		for _, r := range dns.Extra {
 950  			if r != nil && r.Header().Rrtype != TypeOPT {
 951  				s += r.String() + "\n"
 952  			}
 953  		}
 954  	}
 955  	return s
 956  }
 957  
 958  // isCompressible returns whether the msg may be compressible.
 959  func (dns *Msg) isCompressible() bool {
 960  	// If we only have one question, there is nothing we can ever compress.
 961  	return len(dns.Question) > 1 || len(dns.Answer) > 0 ||
 962  		len(dns.Ns) > 0 || len(dns.Extra) > 0
 963  }
 964  
 965  // Len returns the message length when in (un)compressed wire format.
 966  // If dns.Compress is true compression it is taken into account. Len()
 967  // is provided to be a faster way to get the size of the resulting packet,
 968  // than packing it, measuring the size and discarding the buffer.
 969  func (dns *Msg) Len() int {
 970  	// If this message can't be compressed, avoid filling the
 971  	// compression map and creating garbage.
 972  	if dns.Compress && dns.isCompressible() {
 973  		compression := make(map[string]struct{})
 974  		return msgLenWithCompressionMap(dns, compression)
 975  	}
 976  
 977  	return msgLenWithCompressionMap(dns, nil)
 978  }
 979  
 980  func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
 981  	l := headerSize
 982  
 983  	for _, r := range dns.Question {
 984  		l += r.len(l, compression)
 985  	}
 986  	for _, r := range dns.Answer {
 987  		if r != nil {
 988  			l += r.len(l, compression)
 989  		}
 990  	}
 991  	for _, r := range dns.Ns {
 992  		if r != nil {
 993  			l += r.len(l, compression)
 994  		}
 995  	}
 996  	for _, r := range dns.Extra {
 997  		if r != nil {
 998  			l += r.len(l, compression)
 999  		}
1000  	}
1001  
1002  	return l
1003  }
1004  
1005  func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int {
1006  	if s == "" || s == "." {
1007  		return 1
1008  	}
1009  
1010  	escaped := strings.Contains(s, "\\")
1011  
1012  	if compression != nil && (compress || off < maxCompressionOffset) {
1013  		// compressionLenSearch will insert the entry into the compression
1014  		// map if it doesn't contain it.
1015  		if l, ok := compressionLenSearch(compression, s, off); ok && compress {
1016  			if escaped {
1017  				return escapedNameLen(s[:l]) + 2
1018  			}
1019  
1020  			return l + 2
1021  		}
1022  	}
1023  
1024  	if escaped {
1025  		return escapedNameLen(s) + 1
1026  	}
1027  
1028  	return len(s) + 1
1029  }
1030  
1031  func escapedNameLen(s string) int {
1032  	nameLen := len(s)
1033  	for i := 0; i < len(s); i++ {
1034  		if s[i] != '\\' {
1035  			continue
1036  		}
1037  
1038  		if isDDD(s[i+1:]) {
1039  			nameLen -= 3
1040  			i += 3
1041  		} else {
1042  			nameLen--
1043  			i++
1044  		}
1045  	}
1046  
1047  	return nameLen
1048  }
1049  
1050  func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) {
1051  	for off, end := 0, false; !end; off, end = NextLabel(s, off) {
1052  		if _, ok := c[s[off:]]; ok {
1053  			return off, true
1054  		}
1055  
1056  		if msgOff+off < maxCompressionOffset {
1057  			c[s[off:]] = struct{}{}
1058  		}
1059  	}
1060  
1061  	return 0, false
1062  }
1063  
1064  // Copy returns a new RR which is a deep-copy of r.
1065  func Copy(r RR) RR { return r.copy() }
1066  
1067  // Len returns the length (in octets) of the uncompressed RR in wire format.
1068  func Len(r RR) int { return r.len(0, nil) }
1069  
1070  // Copy returns a new *Msg which is a deep-copy of dns.
1071  func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
1072  
1073  // CopyTo copies the contents to the provided message using a deep-copy and returns the copy.
1074  func (dns *Msg) CopyTo(r1 *Msg) *Msg {
1075  	r1.MsgHdr = dns.MsgHdr
1076  	r1.Compress = dns.Compress
1077  
1078  	if len(dns.Question) > 0 {
1079  		// TODO(miek): Question is an immutable value, ok to do a shallow-copy
1080  		r1.Question = cloneSlice(dns.Question)
1081  	}
1082  
1083  	rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
1084  	r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):]
1085  	r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):]
1086  	r1.Extra = rrArr[:0:len(dns.Extra)]
1087  
1088  	for _, r := range dns.Answer {
1089  		r1.Answer = append(r1.Answer, r.copy())
1090  	}
1091  
1092  	for _, r := range dns.Ns {
1093  		r1.Ns = append(r1.Ns, r.copy())
1094  	}
1095  
1096  	for _, r := range dns.Extra {
1097  		r1.Extra = append(r1.Extra, r.copy())
1098  	}
1099  
1100  	return r1
1101  }
1102  
1103  func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
1104  	off, err := packDomainName(q.Name, msg, off, compression, compress)
1105  	if err != nil {
1106  		return off, err
1107  	}
1108  	off, err = packUint16(q.Qtype, msg, off)
1109  	if err != nil {
1110  		return off, err
1111  	}
1112  	off, err = packUint16(q.Qclass, msg, off)
1113  	if err != nil {
1114  		return off, err
1115  	}
1116  	return off, nil
1117  }
1118  
1119  func unpackQuestion(msg []byte, off int) (Question, int, error) {
1120  	var (
1121  		q   Question
1122  		err error
1123  	)
1124  	q.Name, off, err = UnpackDomainName(msg, off)
1125  	if err != nil {
1126  		return q, off, fmt.Errorf("bad question name: %w", err)
1127  	}
1128  	if off == len(msg) {
1129  		return q, off, nil
1130  	}
1131  	q.Qtype, off, err = unpackUint16(msg, off)
1132  	if err != nil {
1133  		return q, off, fmt.Errorf("bad question qtype: %w", err)
1134  	}
1135  	if off == len(msg) {
1136  		return q, off, nil
1137  	}
1138  	q.Qclass, off, err = unpackUint16(msg, off)
1139  	if err != nil {
1140  		return q, off, fmt.Errorf("bad question qclass: %w", err)
1141  	}
1142  
1143  	if off == len(msg) {
1144  		return q, off, nil
1145  	}
1146  
1147  	return q, off, nil
1148  }
1149  
1150  func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
1151  	off, err := packUint16(dh.Id, msg, off)
1152  	if err != nil {
1153  		return off, err
1154  	}
1155  	off, err = packUint16(dh.Bits, msg, off)
1156  	if err != nil {
1157  		return off, err
1158  	}
1159  	off, err = packUint16(dh.Qdcount, msg, off)
1160  	if err != nil {
1161  		return off, err
1162  	}
1163  	off, err = packUint16(dh.Ancount, msg, off)
1164  	if err != nil {
1165  		return off, err
1166  	}
1167  	off, err = packUint16(dh.Nscount, msg, off)
1168  	if err != nil {
1169  		return off, err
1170  	}
1171  	off, err = packUint16(dh.Arcount, msg, off)
1172  	if err != nil {
1173  		return off, err
1174  	}
1175  	return off, nil
1176  }
1177  
1178  func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
1179  	var (
1180  		dh  Header
1181  		err error
1182  	)
1183  	dh.Id, off, err = unpackUint16(msg, off)
1184  	if err != nil {
1185  		return dh, off, fmt.Errorf("bad header id: %w", err)
1186  	}
1187  	dh.Bits, off, err = unpackUint16(msg, off)
1188  	if err != nil {
1189  		return dh, off, fmt.Errorf("bad header bits: %w", err)
1190  	}
1191  	dh.Qdcount, off, err = unpackUint16(msg, off)
1192  	if err != nil {
1193  		return dh, off, fmt.Errorf("bad header question count: %w", err)
1194  	}
1195  	dh.Ancount, off, err = unpackUint16(msg, off)
1196  	if err != nil {
1197  		return dh, off, fmt.Errorf("bad header answer count: %w", err)
1198  	}
1199  	dh.Nscount, off, err = unpackUint16(msg, off)
1200  	if err != nil {
1201  		return dh, off, fmt.Errorf("bad header ns count: %w", err)
1202  	}
1203  	dh.Arcount, off, err = unpackUint16(msg, off)
1204  	if err != nil {
1205  		return dh, off, fmt.Errorf("bad header extra count: %w", err)
1206  	}
1207  	return dh, off, nil
1208  }
1209  
1210  // setHdr set the header in the dns using the binary data in dh.
1211  func (dns *Msg) setHdr(dh Header) {
1212  	dns.Id = dh.Id
1213  	dns.Response = dh.Bits&_QR != 0
1214  	dns.Opcode = int(dh.Bits>>11) & 0xF
1215  	dns.Authoritative = dh.Bits&_AA != 0
1216  	dns.Truncated = dh.Bits&_TC != 0
1217  	dns.RecursionDesired = dh.Bits&_RD != 0
1218  	dns.RecursionAvailable = dh.Bits&_RA != 0
1219  	dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
1220  	dns.AuthenticatedData = dh.Bits&_AD != 0
1221  	dns.CheckingDisabled = dh.Bits&_CD != 0
1222  	dns.Rcode = int(dh.Bits & 0xF)
1223  }
1224