1 // DNS packet assembly, see RFC 1035. Converting from - Unpack() -
2 // and to - Pack() - wire format.
// The packers and unpackers take a (msg []byte, off int) and return the new
// offset together with an error. Most of them return len(msg) as the offset
// when they fail, so that an unpacker that is subsequently started at that
// offset fails as well.
8 9 package dns
10 11 //go:generate go run msg_generate.go
12 13 import (
14 "crypto/rand"
15 "encoding/binary"
16 "fmt"
17 "math/big"
18 "strconv"
19 "strings"
20 )
const (
	// maxCompressionOffset bounds the message offsets a compression pointer
	// can refer to: a pointer carries a 14-bit offset, so only offsets below
	// 1<<14 are usable.
	maxCompressionOffset = 1 << 14

	// maxDomainNameWireOctets is the maximum length of a domain name in wire
	// format. See RFC 1035 section 2.3.4.
	maxDomainNameWireOctets = 255

	// maxCompressionPointers is the maximum number of compression pointers
	// that should occur in a semantically valid message. Every label takes at
	// least two wire octets (a length byte plus at least one octet of data),
	// which bounds the number of labels, and thus of pointers, in one name. The
	// root label is not represented by a compression pointer, hence the -2.
	//
	// A message can exceed this without looping by pointing at an earlier
	// pointer. Well-written implementations never do that, so such names are
	// left to trip this limit.
	maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2

	// maxDomainNamePresentationLength is the maximum length of a domain name
	// in presentation format.
	//
	// The wire format spends one more octet than the presentation format, so
	// at most 254 octets remain for labels and separators. Each octet can
	// expand to at most 4 bytes (\DDD), and each label is followed by a
	// single period. With three labels of the maximum length 63, the final
	// label can only be 61 octets long.
	maxDomainNamePresentationLength = 61*4 + 1 + 3*(63*4+1)
)
48 49 // Errors defined in this package.
50 var (
51 ErrAlg error = &Error{err: "bad algorithm"} // ErrAlg indicates an error with the (DNSSEC) algorithm.
52 ErrAuth error = &Error{err: "bad authentication"} // ErrAuth indicates an error in the TSIG authentication.
53 ErrBuf error = &Error{err: "buffer size too small"} // ErrBuf indicates that the buffer used is too small for the message.
54 ErrConnEmpty error = &Error{err: "conn has no connection"} // ErrConnEmpty indicates a connection is being used before it is initialized.
55 ErrExtendedRcode error = &Error{err: "bad extended rcode"} // ErrExtendedRcode ...
56 ErrFqdn error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot.
57 ErrId error = &Error{err: "id mismatch"} // ErrId indicates there is a mismatch with the message's ID.
58 ErrKeyAlg error = &Error{err: "bad key algorithm"} // ErrKeyAlg indicates that the algorithm in the key is not valid.
59 ErrKey error = &Error{err: "bad key"}
60 ErrKeySize error = &Error{err: "bad key size"}
61 ErrLongDomain error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)}
62 ErrNoSig error = &Error{err: "no signature found"}
63 ErrPrivKey error = &Error{err: "bad private key"}
64 ErrRcode error = &Error{err: "bad rcode"}
65 ErrRdata error = &Error{err: "bad rdata"}
66 ErrRRset error = &Error{err: "bad rrset"}
67 ErrSecret error = &Error{err: "no secrets defined"}
68 ErrShortRead error = &Error{err: "short read"}
69 ErrSig error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
70 ErrSoa error = &Error{err: "no SOA"} // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
71 ErrTime error = &Error{err: "bad time"} // ErrTime indicates a timing error in TSIG authentication.
72 )
// Id by default returns a 16-bit random number to be used as a message id. The
// number is drawn from a cryptographically secure random number generator.
// Because Id is a variable, it can be reassigned to a custom function.
// For instance, to make it return a static value for testing:
//
//	dns.Id = func() uint16 { return 3 }
var Id = id

// id reads a 16-bit big-endian value from crypto/rand and returns it as a
// message id. It panics if reading from the random source fails.
func id() uint16 {
	var v uint16
	if err := binary.Read(rand.Reader, binary.BigEndian, &v); err != nil {
		panic("dns: reading random id failed: " + err.Error())
	}
	return v
}
// MsgHdr is a manually-unpacked version of (id, bits): the message ID plus
// the header flag bits, opcode and rcode as individual fields. When packing,
// each boolean sets its corresponding flag bit in the wire header.
type MsgHdr struct {
	Id                 uint16 // Message ID.
	Response           bool   // QR: the message is a response.
	Opcode             int    // Shifted left by 11 into the header bits when packing.
	Authoritative      bool   // AA flag.
	Truncated          bool   // TC flag.
	RecursionDesired   bool   // RD flag.
	RecursionAvailable bool   // RA flag.
	Zero               bool   // Z flag.
	AuthenticatedData  bool   // AD flag.
	CheckingDisabled   bool   // CD flag.
	// Rcode must be in 0..0xFFF to be packed. Only the low four bits go into
	// the header; anything larger needs an OPT (EDNS0) record for the rest.
	Rcode int
}
// Msg contains the layout of a DNS message: the header fields plus the four
// record sections. For UPDATE messages (OpcodeUpdate) the sections are
// presented as zone, prerequisite, update and additional respectively.
type Msg struct {
	MsgHdr
	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format.
	Question []Question // Holds the RR(s) of the question section.
	Answer   []RR       // Holds the RR(s) of the answer section.
	Ns       []RR       // Holds the RR(s) of the authority section.
	Extra    []RR       // Holds the RR(s) of the additional section.
}
117 118 // ClassToString is a maps Classes to strings for each CLASS wire type.
119 var ClassToString = map[uint16]string{
120 ClassINET: "IN",
121 ClassCSNET: "CS",
122 ClassCHAOS: "CH",
123 ClassHESIOD: "HS",
124 ClassNONE: "NONE",
125 ClassANY: "ANY",
126 }
127 128 // OpcodeToString maps Opcodes to strings.
129 var OpcodeToString = map[int]string{
130 OpcodeQuery: "QUERY",
131 OpcodeIQuery: "IQUERY",
132 OpcodeStatus: "STATUS",
133 OpcodeNotify: "NOTIFY",
134 OpcodeUpdate: "UPDATE",
135 }
136 137 // RcodeToString maps Rcodes to strings.
138 var RcodeToString = map[int]string{
139 RcodeSuccess: "NOERROR",
140 RcodeFormatError: "FORMERR",
141 RcodeServerFailure: "SERVFAIL",
142 RcodeNameError: "NXDOMAIN",
143 RcodeNotImplemented: "NOTIMP",
144 RcodeRefused: "REFUSED",
145 RcodeYXDomain: "YXDOMAIN", // See RFC 2136
146 RcodeYXRrset: "YXRRSET",
147 RcodeNXRrset: "NXRRSET",
148 RcodeNotAuth: "NOTAUTH",
149 RcodeNotZone: "NOTZONE",
150 RcodeStatefulTypeNotImplemented: "DSOTYPENI",
151 RcodeBadSig: "BADSIG", // Also known as RcodeBadVers, see RFC 6891
152 // RcodeBadVers: "BADVERS",
153 RcodeBadKey: "BADKEY",
154 RcodeBadTime: "BADTIME",
155 RcodeBadMode: "BADMODE",
156 RcodeBadName: "BADNAME",
157 RcodeBadAlg: "BADALG",
158 RcodeBadTrunc: "BADTRUNC",
159 RcodeBadCookie: "BADCOOKIE",
160 }
// compressionMap lets internal packDomainName calls use a more compact map
// than the public API's map[string]int, without changing the signature or
// behavior of that API.
//
// In particular, map[string]uint16 uses 25% less per-entry memory than does
// map[string]int. Callers in this file set only one of the two maps; if ext
// is set it takes precedence.
type compressionMap struct {
	ext map[string]int    // external callers
	int map[string]uint16 // internal callers
}

// valid reports whether either backing map is set, i.e. whether name
// offsets should be tracked at all.
func (m compressionMap) valid() bool {
	return !(m.ext == nil && m.int == nil)
}

// insert records that the name s starts at message offset pos. For the
// internal map pos is stored as a uint16.
func (m compressionMap) insert(s string, pos int) {
	if m.ext == nil {
		m.int[s] = uint16(pos)
		return
	}
	m.ext[s] = pos
}

// find returns the offset recorded for s and whether one was recorded.
func (m compressionMap) find(s string) (int, bool) {
	if m.ext == nil {
		pos, ok := m.int[s]
		return int(pos), ok
	}
	pos, ok := m.ext[s]
	return pos, ok
}
194 195 // Domain names are a sequence of counted strings
196 // split at the dots. They end with a zero-length string.
197 198 // PackDomainName packs a domain name s into msg[off:].
199 // If compression is wanted compress must be true and the compression
200 // map needs to hold a mapping between domain names and offsets
201 // pointing into msg.
202 func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
203 return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
204 }
// packDomainName is the implementation behind PackDomainName, taking the
// internal compressionMap wrapper instead of a plain map.
//
// s is written to msg starting at off as a sequence of length-prefixed labels
// ending in a zero byte. Escapes in s (\X and \DDD) are decoded into the
// octets they stand for, so an escaped dot is label data, not a separator.
// The returned offset points just past the bytes written; on any error it is
// len(msg).
//
// If compression is valid, each suffix of s that it does not contain yet is
// recorded with its offset, as long as that offset is below
// maxCompressionOffset. If a suffix is already present and compress is true,
// the rest of the name is replaced by a two-byte compression pointer. The
// root label is never looked up or recorded.
//
// An empty s is accepted and writes nothing. Errors:
//   - ErrFqdn for a name without a closing dot.
//   - ErrRdata for a leading dot (other than the root name "."), two
//     consecutive dots, or a label of 64 octets or more.
//   - ErrBuf when msg is too short.
func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
	// XXX: A logical copy of this function exists in IsDomainName and
	// should be kept in sync with this function.

	ls := len(s)
	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
		return off, nil
	}

	// If not fully qualified, error out.
	if !IsFqdn(s) {
		return len(msg), ErrFqdn
	}

	// Each dot ends a segment of the name.
	// We trade each dot byte for a length byte.
	// Except for escaped dots (\.), which are normal dots.
	// There is also a trailing zero.

	// Compression
	pointer := -1 // offset of an earlier copy of the remaining suffix; -1 when no pointer is emitted

	// Emit sequence of counted strings, chopping at dots.
	var (
		begin     int    // start of the current label in s, or in bs once escapes were collapsed
		compBegin int    // start of the current label in the original (still escaped) s; compression key offset
		compOff   int    // number of escape characters removed so far (compBegin = begin + compOff)
		bs        []byte // mutable copy of s, made on the first escape so escapes can be collapsed in place
		wasDot    bool   // whether the previous character was an unescaped dot
	)
loop:
	for i := 0; i < ls; i++ {
		var c byte
		if bs == nil {
			c = s[i]
		} else {
			c = bs[i]
		}

		switch c {
		case '\\':
			// No room left in msg at all.
			if off+1 > len(msg) {
				return len(msg), ErrBuf
			}

			if bs == nil {
				bs = []byte(s)
			}

			// check for \DDD
			// Collapse the escape in place: the decoded octet lands at i and the
			// rest of bs shifts left, shrinking ls. The loop's i++ then skips
			// the octet, so an escaped '.' is never treated as a separator.
			if isDDD(bs[i+1:]) {
				bs[i] = dddToByte(bs[i+1:])
				copy(bs[i+1:ls-3], bs[i+4:])
				ls -= 3
				compOff += 3
			} else {
				copy(bs[i:ls-1], bs[i+1:])
				ls--
				compOff++
			}

			wasDot = false
		case '.':
			if i == 0 && len(s) > 1 {
				// leading dots are not legal except for the root zone
				return len(msg), ErrRdata
			}

			if wasDot {
				// two dots back to back is not legal
				return len(msg), ErrRdata
			}
			wasDot = true

			labelLen := i - begin
			if labelLen >= 1<<6 { // top two bits of length must be clear
				return len(msg), ErrRdata
			}

			// off can already (we're in a loop) be bigger than len(msg)
			// this happens when a name isn't fully qualified
			if off+1+labelLen > len(msg) {
				return len(msg), ErrBuf
			}

			// Don't try to compress '.'
			// We should only compress when compress is true, but we should also still pick
			// up names that can be used for *future* compression(s).
			if compression.valid() && !isRootLabel(s, bs, begin, ls) {
				if p, ok := compression.find(s[compBegin:]); ok {
					// The first hit is the longest matching dname
					// keep the pointer offset we get back and store
					// the offset of the current name, because that's
					// where we need to insert the pointer later

					// If compress is true, we're allowed to compress this dname
					if compress {
						pointer = p // Where to point to
						break loop
					}
				} else if off < maxCompressionOffset {
					// Only offsets smaller than maxCompressionOffset can be used.
					compression.insert(s[compBegin:], off)
				}
			}

			// The following is covered by the length check above.
			msg[off] = byte(labelLen)

			if bs == nil {
				copy(msg[off+1:], s[begin:i])
			} else {
				copy(msg[off+1:], bs[begin:i])
			}
			off += 1 + labelLen

			begin = i + 1
			compBegin = begin + compOff
		default:
			wasDot = false
		}
	}

	// Root label is special
	// For "." the loop has already written the single zero byte.
	if isRootLabel(s, bs, 0, ls) {
		return off, nil
	}

	// If we did compression and we find something add the pointer here
	if pointer != -1 {
		// We have two bytes (14 bits) to put the pointer in
		// The length check in the loop (off+1+labelLen <= len(msg) with a
		// non-empty label) leaves room for both bytes. XOR with 0xC000 sets
		// the two top bits only if pointer < 0x4000, which holds for offsets
		// this function records; an externally supplied map is assumed to
		// hold such offsets too.
		binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
		return off + 2, nil
	}

	// NOTE(review): if there is no room for the terminating zero it is not
	// written, yet off+1 is still returned. packBufferWithCompressionMap sizes
	// its buffer one byte larger than the uncompressed length, presumably
	// for this reason; confirm for other callers.
	if off < len(msg) {
		msg[off] = 0
	}

	return off + 1, nil
}
// isRootLabel reports whether the span [off, end) is exactly the root label
// ".". The span is taken from bs when bs is non-nil (the escape-collapsed
// copy of the name), and from s otherwise.
func isRootLabel(s string, bs []byte, off, end int) bool {
	if bs != nil {
		return end-off == 1 && bs[off] == '.'
	}
	return s[off:end] == "."
}
359 360 // Unpack a domain name.
361 // In addition to the simple sequences of counted strings above,
362 // domain names are allowed to refer to strings elsewhere in the
363 // packet, to avoid repeating common suffixes when returning
364 // many entries in a single domain. The pointers are marked
365 // by a length byte with the top two bits set. Ignoring those
366 // two bits, that byte and the next give a 14 bit offset from msg[0]
367 // where we should pick up the trail.
368 // Note that if we jump elsewhere in the packet,
369 // we return off1 == the offset after the first pointer we found,
370 // which is where the next record will start.
371 // In theory, the pointers are only allowed to jump backward.
372 // We let them jump anywhere and stop jumping after a while.
373 374 // UnpackDomainName unpacks a domain name into a string. It returns
375 // the name, the new offset into msg and any error that occurred.
376 //
377 // When an error is encountered, the unpacked name will be discarded
378 // and len(msg) will be returned as the offset.
379 func UnpackDomainName(msg []byte, off int) (string, int, error) {
380 s := make([]byte, 0, maxDomainNamePresentationLength)
381 off1 := 0
382 lenmsg := len(msg)
383 budget := maxDomainNameWireOctets
384 ptr := 0 // number of pointers followed
385 Loop:
386 for {
387 if off >= lenmsg {
388 return "", lenmsg, ErrBuf
389 }
390 c := int(msg[off])
391 off++
392 switch c & 0xC0 {
393 case 0x00:
394 if c == 0x00 {
395 // end of name
396 break Loop
397 }
398 // literal string
399 if off+c > lenmsg {
400 return "", lenmsg, ErrBuf
401 }
402 budget -= c + 1 // +1 for the label separator
403 if budget <= 0 {
404 return "", lenmsg, ErrLongDomain
405 }
406 for _, b := range msg[off : off+c] {
407 if isDomainNameLabelSpecial(b) {
408 s = append(s, '\\', b)
409 } else if b < ' ' || b > '~' {
410 s = append(s, escapeByte(b)...)
411 } else {
412 s = append(s, b)
413 }
414 }
415 s = append(s, '.')
416 off += c
417 case 0xC0:
418 // pointer to somewhere else in msg.
419 // remember location after first ptr,
420 // since that's how many bytes we consumed.
421 // also, don't follow too many pointers --
422 // maybe there's a loop.
423 if off >= lenmsg {
424 return "", lenmsg, ErrBuf
425 }
426 c1 := msg[off]
427 off++
428 if ptr == 0 {
429 off1 = off
430 }
431 if ptr++; ptr > maxCompressionPointers {
432 return "", lenmsg, &Error{err: "too many compression pointers"}
433 }
434 // pointer should guarantee that it advances and points forwards at least
435 // but the condition on previous three lines guarantees that it's
436 // at least loop-free
437 off = (c^0xC0)<<8 | int(c1)
438 default:
439 // 0x80 and 0x40 are reserved
440 return "", lenmsg, ErrRdata
441 }
442 }
443 if ptr == 0 {
444 off1 = off
445 }
446 if len(s) == 0 {
447 return ".", off1, nil
448 }
449 return string(s), off1, nil
450 }
451 452 func packTxt(txt []string, msg []byte, offset int) (int, error) {
453 if len(txt) == 0 {
454 if offset >= len(msg) {
455 return offset, ErrBuf
456 }
457 msg[offset] = 0
458 return offset, nil
459 }
460 var err error
461 for _, s := range txt {
462 offset, err = packTxtString(s, msg, offset)
463 if err != nil {
464 return offset, err
465 }
466 }
467 return offset, nil
468 }
469 470 func packTxtString(s string, msg []byte, offset int) (int, error) {
471 lenByteOffset := offset
472 if offset >= len(msg) || len(s) > 256*4+1 /* If all \DDD */ {
473 return offset, ErrBuf
474 }
475 offset++
476 for i := 0; i < len(s); i++ {
477 if len(msg) <= offset {
478 return offset, ErrBuf
479 }
480 if s[i] == '\\' {
481 i++
482 if i == len(s) {
483 break
484 }
485 // check for \DDD
486 if isDDD(s[i:]) {
487 msg[offset] = dddToByte(s[i:])
488 i += 2
489 } else {
490 msg[offset] = s[i]
491 }
492 } else {
493 msg[offset] = s[i]
494 }
495 offset++
496 }
497 l := offset - lenByteOffset - 1
498 if l > 255 {
499 return offset, &Error{err: "string exceeded 255 bytes in txt"}
500 }
501 msg[lenByteOffset] = byte(l)
502 return offset, nil
503 }
504 505 func packOctetString(s string, msg []byte, offset int) (int, error) {
506 if offset >= len(msg) || len(s) > 256*4+1 {
507 return offset, ErrBuf
508 }
509 for i := 0; i < len(s); i++ {
510 if len(msg) <= offset {
511 return offset, ErrBuf
512 }
513 if s[i] == '\\' {
514 i++
515 if i == len(s) {
516 break
517 }
518 // check for \DDD
519 if isDDD(s[i:]) {
520 msg[offset] = dddToByte(s[i:])
521 i += 2
522 } else {
523 msg[offset] = s[i]
524 }
525 } else {
526 msg[offset] = s[i]
527 }
528 offset++
529 }
530 return offset, nil
531 }
532 533 func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
534 off = off0
535 var s string
536 for off < len(msg) && err == nil {
537 s, off, err = unpackString(msg, off)
538 if err == nil {
539 ss = append(ss, s)
540 }
541 }
542 return
543 }
// Helpers for dealing with escaped bytes.

// isDigit reports whether b is an ASCII decimal digit.
func isDigit(b byte) bool { return '0' <= b && b <= '9' }

// isDDD reports whether s starts with three decimal digits, i.e. the DDD of
// a \DDD escape.
func isDDD[T ~[]byte | ~string](s T) bool {
	if len(s) < 3 {
		return false
	}
	return isDigit(s[0]) && isDigit(s[1]) && isDigit(s[2])
}

// dddToByte decodes the three leading decimal digits of s into a byte.
// Values above 255 wrap modulo 256. It panics if len(s) < 3.
func dddToByte[T ~[]byte | ~string](s T) byte {
	_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
	hundreds, tens, ones := s[0]-'0', s[1]-'0', s[2]-'0'
	return hundreds*100 + tens*10 + ones
}
// intToBytes returns the big-endian bytes of i, left-padded with zeros to
// length. If i needs more than length bytes, its full representation is
// returned unchanged (no truncation).
func intToBytes(i *big.Int, length int) []byte {
	buf := i.Bytes()
	if pad := length - len(buf); pad > 0 {
		return append(make([]byte, pad, length), buf...)
	}
	return buf
}
567 568 // PackRR packs a resource record rr into msg[off:].
569 // See PackDomainName for documentation about the compression.
570 func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
571 headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
572 if err == nil {
573 // packRR no longer sets the Rdlength field on the rr, but
574 // callers might be expecting it so we set it here.
575 rr.Header().Rdlength = uint16(off1 - headerEnd)
576 }
577 return off1, err
578 }
579 580 func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
581 if rr == nil {
582 return len(msg), len(msg), &Error{err: "nil rr"}
583 }
584 585 headerEnd, err = rr.Header().packHeader(msg, off, compression, compress)
586 if err != nil {
587 return headerEnd, len(msg), err
588 }
589 590 off1, err = rr.pack(msg, headerEnd, compression, compress)
591 if err != nil {
592 return headerEnd, len(msg), err
593 }
594 595 rdlength := off1 - headerEnd
596 if int(uint16(rdlength)) != rdlength { // overflow
597 return headerEnd, len(msg), ErrRdata
598 }
599 600 // The RDLENGTH field is the last field in the header and we set it here.
601 binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
602 return headerEnd, off1, nil
603 }
604 605 // UnpackRR unpacks msg[off:] into an RR.
606 func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
607 h, off, msg, err := unpackHeader(msg, off)
608 if err != nil {
609 return nil, len(msg), err
610 }
611 612 return UnpackRRWithHeader(h, msg, off)
613 }
614 615 // UnpackRRWithHeader unpacks the record type specific payload given an existing
616 // RR_Header.
617 func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
618 if newFn, ok := TypeToRR[h.Rrtype]; ok {
619 rr = newFn()
620 *rr.Header() = h
621 } else {
622 rr = &RFC3597{Hdr: h}
623 }
624 625 if off < 0 || off > len(msg) {
626 return &h, off, &Error{err: "bad off"}
627 }
628 629 end := off + int(h.Rdlength)
630 if end < off || end > len(msg) {
631 return &h, end, &Error{err: "bad rdlength"}
632 }
633 634 if noRdata(h) {
635 return rr, off, nil
636 }
637 638 off, err = rr.unpack(msg, off)
639 if err != nil {
640 return nil, end, err
641 }
642 if off != end {
643 return &h, end, &Error{err: "bad rdlength"}
644 }
645 646 return rr, off, nil
647 }
648 649 // unpackRRslice unpacks msg[off:] into an []RR.
650 // If we cannot unpack the whole array, then it will return nil
651 func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
652 var r RR
653 // Don't pre-allocate, l may be under attacker control
654 var dst []RR
655 for i := 0; i < l; i++ {
656 off1 := off
657 r, off, err = UnpackRR(msg, off)
658 if err != nil {
659 off = len(msg)
660 break
661 }
662 // If offset does not increase anymore, l is a lie
663 if off1 == off {
664 break
665 }
666 dst = append(dst, r)
667 }
668 if err != nil && off == len(msg) {
669 dst = nil
670 }
671 return dst, off, err
672 }
673 674 // Convert a MsgHdr to a string, with dig-like headers:
675 //
676 // ;; opcode: QUERY, status: NOERROR, id: 48404
677 //
678 // ;; flags: qr aa rd ra;
679 func (h *MsgHdr) String() string {
680 if h == nil {
681 return "<nil> MsgHdr"
682 }
683 684 s := ";; opcode: " + OpcodeToString[h.Opcode]
685 s += ", status: " + RcodeToString[h.Rcode]
686 s += ", id: " + strconv.Itoa(int(h.Id)) + "\n"
687 688 s += ";; flags:"
689 if h.Response {
690 s += " qr"
691 }
692 if h.Authoritative {
693 s += " aa"
694 }
695 if h.Truncated {
696 s += " tc"
697 }
698 if h.RecursionDesired {
699 s += " rd"
700 }
701 if h.RecursionAvailable {
702 s += " ra"
703 }
704 if h.Zero { // Hmm
705 s += " z"
706 }
707 if h.AuthenticatedData {
708 s += " ad"
709 }
710 if h.CheckingDisabled {
711 s += " cd"
712 }
713 714 s += ";"
715 return s
716 }
717 718 // Pack packs a Msg: it is converted to wire format.
719 // If the dns.Compress is true the message will be in compressed wire format.
720 func (dns *Msg) Pack() (msg []byte, err error) {
721 return dns.PackBuffer(nil)
722 }
723 724 // PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
725 func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
726 // If this message can't be compressed, avoid filling the
727 // compression map and creating garbage.
728 if dns.Compress && dns.isCompressible() {
729 compression := make(map[string]uint16) // Compression pointer mappings.
730 return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
731 }
732 733 return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
734 }
735 736 // packBufferWithCompressionMap packs a Msg, using the given buffer buf.
737 func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) {
738 if dns.Rcode < 0 || dns.Rcode > 0xFFF {
739 return nil, ErrRcode
740 }
741 742 // Set extended rcode unconditionally if we have an opt, this will allow
743 // resetting the extended rcode bits if they need to.
744 if opt := dns.IsEdns0(); opt != nil {
745 opt.SetExtendedRcode(uint16(dns.Rcode))
746 } else if dns.Rcode > 0xF {
747 // If Rcode is an extended one and opt is nil, error out.
748 return nil, ErrExtendedRcode
749 }
750 751 // Convert convenient Msg into wire-like Header.
752 var dh Header
753 dh.Id = dns.Id
754 dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
755 if dns.Response {
756 dh.Bits |= _QR
757 }
758 if dns.Authoritative {
759 dh.Bits |= _AA
760 }
761 if dns.Truncated {
762 dh.Bits |= _TC
763 }
764 if dns.RecursionDesired {
765 dh.Bits |= _RD
766 }
767 if dns.RecursionAvailable {
768 dh.Bits |= _RA
769 }
770 if dns.Zero {
771 dh.Bits |= _Z
772 }
773 if dns.AuthenticatedData {
774 dh.Bits |= _AD
775 }
776 if dns.CheckingDisabled {
777 dh.Bits |= _CD
778 }
779 780 dh.Qdcount = uint16(len(dns.Question))
781 dh.Ancount = uint16(len(dns.Answer))
782 dh.Nscount = uint16(len(dns.Ns))
783 dh.Arcount = uint16(len(dns.Extra))
784 785 // We need the uncompressed length here, because we first pack it and then compress it.
786 msg = buf
787 uncompressedLen := msgLenWithCompressionMap(dns, nil)
788 if packLen := uncompressedLen + 1; len(msg) < packLen {
789 msg = make([]byte, packLen)
790 }
791 792 // Pack it in: header and then the pieces.
793 off := 0
794 off, err = dh.pack(msg, off, compression, compress)
795 if err != nil {
796 return nil, err
797 }
798 for _, r := range dns.Question {
799 off, err = r.pack(msg, off, compression, compress)
800 if err != nil {
801 return nil, err
802 }
803 }
804 for _, r := range dns.Answer {
805 _, off, err = packRR(r, msg, off, compression, compress)
806 if err != nil {
807 return nil, err
808 }
809 }
810 for _, r := range dns.Ns {
811 _, off, err = packRR(r, msg, off, compression, compress)
812 if err != nil {
813 return nil, err
814 }
815 }
816 for _, r := range dns.Extra {
817 _, off, err = packRR(r, msg, off, compression, compress)
818 if err != nil {
819 return nil, err
820 }
821 }
822 return msg[:off], nil
823 }
824 825 func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
826 // If we are at the end of the message we should return *just* the
827 // header. This can still be useful to the caller. 9.9.9.9 sends these
828 // when responding with REFUSED for instance.
829 if off == len(msg) {
830 // reset sections before returning
831 dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil
832 return nil
833 }
834 835 // Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are
836 // attacker controlled. This means we can't use them to pre-allocate
837 // slices.
838 dns.Question = nil
839 for i := 0; i < int(dh.Qdcount); i++ {
840 off1 := off
841 var q Question
842 q, off, err = unpackQuestion(msg, off)
843 if err != nil {
844 return err
845 }
846 if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
847 dh.Qdcount = uint16(i)
848 break
849 }
850 dns.Question = append(dns.Question, q)
851 }
852 853 dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
854 // The header counts might have been wrong so we need to update it
855 dh.Ancount = uint16(len(dns.Answer))
856 if err == nil {
857 dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
858 }
859 // The header counts might have been wrong so we need to update it
860 dh.Nscount = uint16(len(dns.Ns))
861 if err == nil {
862 dns.Extra, _, err = unpackRRslice(int(dh.Arcount), msg, off)
863 }
864 // The header counts might have been wrong so we need to update it
865 dh.Arcount = uint16(len(dns.Extra))
866 867 // Set extended Rcode
868 if opt := dns.IsEdns0(); opt != nil {
869 dns.Rcode |= opt.ExtendedRcode()
870 }
871 872 // TODO(miek) make this an error?
873 // use PackOpt to let people tell how detailed the error reporting should be?
874 // if off != len(msg) {
875 // // println("dns: extra bytes in dns packet", off, "<", len(msg))
876 // }
877 return err
878 }
879 880 // Unpack unpacks a binary message to a Msg structure.
881 func (dns *Msg) Unpack(msg []byte) (err error) {
882 dh, off, err := unpackMsgHdr(msg, 0)
883 if err != nil {
884 return err
885 }
886 887 dns.setHdr(dh)
888 return dns.unpack(dh, msg, off)
889 }
890 891 // Convert a complete message to a string with dig-like output.
892 func (dns *Msg) String() string {
893 if dns == nil {
894 return "<nil> MsgHdr"
895 }
896 s := dns.MsgHdr.String() + " "
897 if dns.MsgHdr.Opcode == OpcodeUpdate {
898 s += "ZONE: " + strconv.Itoa(len(dns.Question)) + ", "
899 s += "PREREQ: " + strconv.Itoa(len(dns.Answer)) + ", "
900 s += "UPDATE: " + strconv.Itoa(len(dns.Ns)) + ", "
901 s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
902 } else {
903 s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", "
904 s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
905 s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
906 s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
907 }
908 opt := dns.IsEdns0()
909 if opt != nil {
910 // OPT PSEUDOSECTION
911 s += opt.String() + "\n"
912 }
913 if len(dns.Question) > 0 {
914 if dns.MsgHdr.Opcode == OpcodeUpdate {
915 s += "\n;; ZONE SECTION:\n"
916 } else {
917 s += "\n;; QUESTION SECTION:\n"
918 }
919 for _, r := range dns.Question {
920 s += r.String() + "\n"
921 }
922 }
923 if len(dns.Answer) > 0 {
924 if dns.MsgHdr.Opcode == OpcodeUpdate {
925 s += "\n;; PREREQUISITE SECTION:\n"
926 } else {
927 s += "\n;; ANSWER SECTION:\n"
928 }
929 for _, r := range dns.Answer {
930 if r != nil {
931 s += r.String() + "\n"
932 }
933 }
934 }
935 if len(dns.Ns) > 0 {
936 if dns.MsgHdr.Opcode == OpcodeUpdate {
937 s += "\n;; UPDATE SECTION:\n"
938 } else {
939 s += "\n;; AUTHORITY SECTION:\n"
940 }
941 for _, r := range dns.Ns {
942 if r != nil {
943 s += r.String() + "\n"
944 }
945 }
946 }
947 if len(dns.Extra) > 0 && (opt == nil || len(dns.Extra) > 1) {
948 s += "\n;; ADDITIONAL SECTION:\n"
949 for _, r := range dns.Extra {
950 if r != nil && r.Header().Rrtype != TypeOPT {
951 s += r.String() + "\n"
952 }
953 }
954 }
955 return s
956 }
957 958 // isCompressible returns whether the msg may be compressible.
959 func (dns *Msg) isCompressible() bool {
960 // If we only have one question, there is nothing we can ever compress.
961 return len(dns.Question) > 1 || len(dns.Answer) > 0 ||
962 len(dns.Ns) > 0 || len(dns.Extra) > 0
963 }
964 965 // Len returns the message length when in (un)compressed wire format.
966 // If dns.Compress is true compression it is taken into account. Len()
967 // is provided to be a faster way to get the size of the resulting packet,
968 // than packing it, measuring the size and discarding the buffer.
969 func (dns *Msg) Len() int {
970 // If this message can't be compressed, avoid filling the
971 // compression map and creating garbage.
972 if dns.Compress && dns.isCompressible() {
973 compression := make(map[string]struct{})
974 return msgLenWithCompressionMap(dns, compression)
975 }
976 977 return msgLenWithCompressionMap(dns, nil)
978 }
979 980 func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
981 l := headerSize
982 983 for _, r := range dns.Question {
984 l += r.len(l, compression)
985 }
986 for _, r := range dns.Answer {
987 if r != nil {
988 l += r.len(l, compression)
989 }
990 }
991 for _, r := range dns.Ns {
992 if r != nil {
993 l += r.len(l, compression)
994 }
995 }
996 for _, r := range dns.Extra {
997 if r != nil {
998 l += r.len(l, compression)
999 }
1000 }
1001 1002 return l
1003 }
1004 1005 func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int {
1006 if s == "" || s == "." {
1007 return 1
1008 }
1009 1010 escaped := strings.Contains(s, "\\")
1011 1012 if compression != nil && (compress || off < maxCompressionOffset) {
1013 // compressionLenSearch will insert the entry into the compression
1014 // map if it doesn't contain it.
1015 if l, ok := compressionLenSearch(compression, s, off); ok && compress {
1016 if escaped {
1017 return escapedNameLen(s[:l]) + 2
1018 }
1019 1020 return l + 2
1021 }
1022 }
1023 1024 if escaped {
1025 return escapedNameLen(s) + 1
1026 }
1027 1028 return len(s) + 1
1029 }
1030 1031 func escapedNameLen(s string) int {
1032 nameLen := len(s)
1033 for i := 0; i < len(s); i++ {
1034 if s[i] != '\\' {
1035 continue
1036 }
1037 1038 if isDDD(s[i+1:]) {
1039 nameLen -= 3
1040 i += 3
1041 } else {
1042 nameLen--
1043 i++
1044 }
1045 }
1046 1047 return nameLen
1048 }
1049 1050 func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) {
1051 for off, end := 0, false; !end; off, end = NextLabel(s, off) {
1052 if _, ok := c[s[off:]]; ok {
1053 return off, true
1054 }
1055 1056 if msgOff+off < maxCompressionOffset {
1057 c[s[off:]] = struct{}{}
1058 }
1059 }
1060 1061 return 0, false
1062 }
1063 1064 // Copy returns a new RR which is a deep-copy of r.
1065 func Copy(r RR) RR { return r.copy() }
1066 1067 // Len returns the length (in octets) of the uncompressed RR in wire format.
1068 func Len(r RR) int { return r.len(0, nil) }
1069 1070 // Copy returns a new *Msg which is a deep-copy of dns.
1071 func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
1072 1073 // CopyTo copies the contents to the provided message using a deep-copy and returns the copy.
1074 func (dns *Msg) CopyTo(r1 *Msg) *Msg {
1075 r1.MsgHdr = dns.MsgHdr
1076 r1.Compress = dns.Compress
1077 1078 if len(dns.Question) > 0 {
1079 // TODO(miek): Question is an immutable value, ok to do a shallow-copy
1080 r1.Question = cloneSlice(dns.Question)
1081 }
1082 1083 rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
1084 r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):]
1085 r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):]
1086 r1.Extra = rrArr[:0:len(dns.Extra)]
1087 1088 for _, r := range dns.Answer {
1089 r1.Answer = append(r1.Answer, r.copy())
1090 }
1091 1092 for _, r := range dns.Ns {
1093 r1.Ns = append(r1.Ns, r.copy())
1094 }
1095 1096 for _, r := range dns.Extra {
1097 r1.Extra = append(r1.Extra, r.copy())
1098 }
1099 1100 return r1
1101 }
1102 1103 func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
1104 off, err := packDomainName(q.Name, msg, off, compression, compress)
1105 if err != nil {
1106 return off, err
1107 }
1108 off, err = packUint16(q.Qtype, msg, off)
1109 if err != nil {
1110 return off, err
1111 }
1112 off, err = packUint16(q.Qclass, msg, off)
1113 if err != nil {
1114 return off, err
1115 }
1116 return off, nil
1117 }
1118 1119 func unpackQuestion(msg []byte, off int) (Question, int, error) {
1120 var (
1121 q Question
1122 err error
1123 )
1124 q.Name, off, err = UnpackDomainName(msg, off)
1125 if err != nil {
1126 return q, off, fmt.Errorf("bad question name: %w", err)
1127 }
1128 if off == len(msg) {
1129 return q, off, nil
1130 }
1131 q.Qtype, off, err = unpackUint16(msg, off)
1132 if err != nil {
1133 return q, off, fmt.Errorf("bad question qtype: %w", err)
1134 }
1135 if off == len(msg) {
1136 return q, off, nil
1137 }
1138 q.Qclass, off, err = unpackUint16(msg, off)
1139 if err != nil {
1140 return q, off, fmt.Errorf("bad question qclass: %w", err)
1141 }
1142 1143 if off == len(msg) {
1144 return q, off, nil
1145 }
1146 1147 return q, off, nil
1148 }
1149 1150 func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
1151 off, err := packUint16(dh.Id, msg, off)
1152 if err != nil {
1153 return off, err
1154 }
1155 off, err = packUint16(dh.Bits, msg, off)
1156 if err != nil {
1157 return off, err
1158 }
1159 off, err = packUint16(dh.Qdcount, msg, off)
1160 if err != nil {
1161 return off, err
1162 }
1163 off, err = packUint16(dh.Ancount, msg, off)
1164 if err != nil {
1165 return off, err
1166 }
1167 off, err = packUint16(dh.Nscount, msg, off)
1168 if err != nil {
1169 return off, err
1170 }
1171 off, err = packUint16(dh.Arcount, msg, off)
1172 if err != nil {
1173 return off, err
1174 }
1175 return off, nil
1176 }
1177 1178 func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
1179 var (
1180 dh Header
1181 err error
1182 )
1183 dh.Id, off, err = unpackUint16(msg, off)
1184 if err != nil {
1185 return dh, off, fmt.Errorf("bad header id: %w", err)
1186 }
1187 dh.Bits, off, err = unpackUint16(msg, off)
1188 if err != nil {
1189 return dh, off, fmt.Errorf("bad header bits: %w", err)
1190 }
1191 dh.Qdcount, off, err = unpackUint16(msg, off)
1192 if err != nil {
1193 return dh, off, fmt.Errorf("bad header question count: %w", err)
1194 }
1195 dh.Ancount, off, err = unpackUint16(msg, off)
1196 if err != nil {
1197 return dh, off, fmt.Errorf("bad header answer count: %w", err)
1198 }
1199 dh.Nscount, off, err = unpackUint16(msg, off)
1200 if err != nil {
1201 return dh, off, fmt.Errorf("bad header ns count: %w", err)
1202 }
1203 dh.Arcount, off, err = unpackUint16(msg, off)
1204 if err != nil {
1205 return dh, off, fmt.Errorf("bad header extra count: %w", err)
1206 }
1207 return dh, off, nil
1208 }
1209 1210 // setHdr set the header in the dns using the binary data in dh.
1211 func (dns *Msg) setHdr(dh Header) {
1212 dns.Id = dh.Id
1213 dns.Response = dh.Bits&_QR != 0
1214 dns.Opcode = int(dh.Bits>>11) & 0xF
1215 dns.Authoritative = dh.Bits&_AA != 0
1216 dns.Truncated = dh.Bits&_TC != 0
1217 dns.RecursionDesired = dh.Bits&_RD != 0
1218 dns.RecursionAvailable = dh.Bits&_RA != 0
1219 dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
1220 dns.AuthenticatedData = dh.Bits&_AD != 0
1221 dns.CheckingDisabled = dh.Bits&_CD != 0
1222 dns.Rcode = int(dh.Bits & 0xF)
1223 }
1224