xfr.go raw

   1  package dns
   2  
   3  import (
   4  	"crypto/tls"
   5  	"fmt"
   6  	"time"
   7  )
   8  
   9  // Envelope is used when doing a zone transfer with a remote server.
  10  type Envelope struct {
  11  	RR    []RR  // The set of RRs in the answer section of the xfr reply message.
  12  	Error error // If something went wrong, this contains the error.
  13  }
  14  
  15  // A Transfer defines parameters that are used during a zone transfer.
  16  type Transfer struct {
  17  	*Conn
  18  	DialTimeout    time.Duration     // net.DialTimeout, defaults to 2 seconds
  19  	ReadTimeout    time.Duration     // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
  20  	WriteTimeout   time.Duration     // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
  21  	TsigProvider   TsigProvider      // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
  22  	TsigSecret     map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
  23  	tsigTimersOnly bool
  24  	TLS            *tls.Config // TLS config. If Xfr over TLS will be attempted
  25  }
  26  
  27  func (t *Transfer) tsigProvider() TsigProvider {
  28  	if t.TsigProvider != nil {
  29  		return t.TsigProvider
  30  	}
  31  	if t.TsigSecret != nil {
  32  		return tsigSecretProvider(t.TsigSecret)
  33  	}
  34  	return nil
  35  }
  36  
// TODO: I think we need a way to stop the transfer.
  38  
  39  // In performs an incoming transfer with the server in a.
  40  // If you would like to set the source IP, or some other attribute
  41  // of a Dialer for a Transfer, you can do so by specifying the attributes
  42  // in the Transfer.Conn:
  43  //
  44  //	d := net.Dialer{LocalAddr: transfer_source}
  45  //	con, err := d.Dial("tcp", master)
  46  //	dnscon := &dns.Conn{Conn:con}
  47  //	transfer = &dns.Transfer{Conn: dnscon}
  48  //	channel, err := transfer.In(message, master)
  49  func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
  50  	switch q.Question[0].Qtype {
  51  	case TypeAXFR, TypeIXFR:
  52  	default:
  53  		return nil, &Error{"unsupported question type"}
  54  	}
  55  
  56  	timeout := dnsTimeout
  57  	if t.DialTimeout != 0 {
  58  		timeout = t.DialTimeout
  59  	}
  60  
  61  	if t.Conn == nil {
  62  		if t.TLS != nil {
  63  			t.Conn, err = DialTimeoutWithTLS("tcp-tls", a, t.TLS, timeout)
  64  		} else {
  65  			t.Conn, err = DialTimeout("tcp", a, timeout)
  66  		}
  67  		if err != nil {
  68  			return nil, err
  69  		}
  70  	}
  71  
  72  	if err := t.WriteMsg(q); err != nil {
  73  		return nil, err
  74  	}
  75  
  76  	env = make(chan *Envelope)
  77  	switch q.Question[0].Qtype {
  78  	case TypeAXFR:
  79  		go t.inAxfr(q, env)
  80  	case TypeIXFR:
  81  		go t.inIxfr(q, env)
  82  	}
  83  
  84  	return env, nil
  85  }
  86  
  87  func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) {
  88  	first := true
  89  	defer func() {
  90  		// First close the connection, then the channel. This allows functions blocked on
  91  		// the channel to assume that the connection is closed and no further operations are
  92  		// pending when they resume.
  93  		t.Close()
  94  		close(c)
  95  	}()
  96  	timeout := dnsTimeout
  97  	if t.ReadTimeout != 0 {
  98  		timeout = t.ReadTimeout
  99  	}
 100  	for {
 101  		t.Conn.SetReadDeadline(time.Now().Add(timeout))
 102  		in, err := t.ReadMsg()
 103  		if err != nil {
 104  			c <- &Envelope{nil, err}
 105  			return
 106  		}
 107  		if q.Id != in.Id {
 108  			c <- &Envelope{in.Answer, ErrId}
 109  			return
 110  		}
 111  		if first {
 112  			if in.Rcode != RcodeSuccess {
 113  				c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
 114  				return
 115  			}
 116  			if !isSOAFirst(in) {
 117  				c <- &Envelope{in.Answer, ErrSoa}
 118  				return
 119  			}
 120  			first = !first
 121  			// only one answer that is SOA, receive more
 122  			if len(in.Answer) == 1 {
 123  				t.tsigTimersOnly = true
 124  				c <- &Envelope{in.Answer, nil}
 125  				continue
 126  			}
 127  		}
 128  
 129  		if !first {
 130  			t.tsigTimersOnly = true // Subsequent envelopes use this.
 131  			if isSOALast(in) {
 132  				c <- &Envelope{in.Answer, nil}
 133  				return
 134  			}
 135  			c <- &Envelope{in.Answer, nil}
 136  		}
 137  	}
 138  }
 139  
 140  func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) {
 141  	var serial uint32 // The first serial seen is the current server serial
 142  	axfr := true
 143  	n := 0
 144  	qser := q.Ns[0].(*SOA).Serial
 145  	defer func() {
 146  		// First close the connection, then the channel. This allows functions blocked on
 147  		// the channel to assume that the connection is closed and no further operations are
 148  		// pending when they resume.
 149  		t.Close()
 150  		close(c)
 151  	}()
 152  	timeout := dnsTimeout
 153  	if t.ReadTimeout != 0 {
 154  		timeout = t.ReadTimeout
 155  	}
 156  	for {
 157  		t.SetReadDeadline(time.Now().Add(timeout))
 158  		in, err := t.ReadMsg()
 159  		if err != nil {
 160  			c <- &Envelope{nil, err}
 161  			return
 162  		}
 163  		if q.Id != in.Id {
 164  			c <- &Envelope{in.Answer, ErrId}
 165  			return
 166  		}
 167  		if in.Rcode != RcodeSuccess {
 168  			c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
 169  			return
 170  		}
 171  		if n == 0 {
 172  			// Check if the returned answer is ok
 173  			if !isSOAFirst(in) {
 174  				c <- &Envelope{in.Answer, ErrSoa}
 175  				return
 176  			}
 177  			// This serial is important
 178  			serial = in.Answer[0].(*SOA).Serial
 179  			// Check if there are no changes in zone
 180  			if qser >= serial {
 181  				c <- &Envelope{in.Answer, nil}
 182  				return
 183  			}
 184  		}
 185  		// Now we need to check each message for SOA records, to see what we need to do
 186  		t.tsigTimersOnly = true
 187  		for _, rr := range in.Answer {
 188  			if v, ok := rr.(*SOA); ok {
 189  				if v.Serial == serial {
 190  					n++
 191  					// quit if it's a full axfr or the servers' SOA is repeated the third time
 192  					if axfr && n == 2 || n == 3 {
 193  						c <- &Envelope{in.Answer, nil}
 194  						return
 195  					}
 196  				} else if axfr {
 197  					// it's an ixfr
 198  					axfr = false
 199  				}
 200  			}
 201  		}
 202  		c <- &Envelope{in.Answer, nil}
 203  	}
 204  }
 205  
 206  // Out performs an outgoing transfer with the client connecting in w.
 207  // Basic use pattern:
 208  //
 209  //	ch := make(chan *dns.Envelope)
 210  //	tr := new(dns.Transfer)
 211  //	var wg sync.WaitGroup
 212  //	wg.Add(1)
 213  //	go func() {
 214  //		tr.Out(w, r, ch)
 215  //		wg.Done()
 216  //	}()
 217  //	ch <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}}
 218  //	close(ch)
 219  //	wg.Wait() // wait until everything is written out
 220  //	w.Close() // close connection
 221  //
 222  // The server is responsible for sending the correct sequence of RRs through the channel ch.
 223  func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
 224  	for x := range ch {
 225  		r := new(Msg)
 226  		// Compress?
 227  		r.SetReply(q)
 228  		r.Authoritative = true
 229  		// assume it fits TODO(miek): fix
 230  		r.Answer = append(r.Answer, x.RR...)
 231  		if tsig := q.IsTsig(); tsig != nil && w.TsigStatus() == nil {
 232  			r.SetTsig(tsig.Hdr.Name, tsig.Algorithm, tsig.Fudge, time.Now().Unix())
 233  		}
 234  		if err := w.WriteMsg(r); err != nil {
 235  			return err
 236  		}
 237  		w.TsigTimersOnly(true)
 238  	}
 239  	return nil
 240  }
 241  
 242  // ReadMsg reads a message from the transfer connection t.
 243  func (t *Transfer) ReadMsg() (*Msg, error) {
 244  	m := new(Msg)
 245  	p := make([]byte, MaxMsgSize)
 246  	n, err := t.Read(p)
 247  	if err != nil && n == 0 {
 248  		return nil, err
 249  	}
 250  	p = p[:n]
 251  	if err := m.Unpack(p); err != nil {
 252  		return nil, err
 253  	}
 254  
 255  	if tp := t.tsigProvider(); tp != nil {
 256  		// Need to work on the original message p, as that was used to calculate the tsig.
 257  		err = TsigVerifyWithProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly)
 258  		if ts := m.IsTsig(); ts != nil {
 259  			t.tsigRequestMAC = ts.MAC
 260  		}
 261  	}
 262  	return m, err
 263  }
 264  
 265  // WriteMsg writes a message through the transfer connection t.
 266  func (t *Transfer) WriteMsg(m *Msg) (err error) {
 267  	var out []byte
 268  	if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
 269  		out, t.tsigRequestMAC, err = TsigGenerateWithProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly)
 270  	} else {
 271  		out, err = m.Pack()
 272  	}
 273  	if err != nil {
 274  		return err
 275  	}
 276  	_, err = t.Write(out)
 277  	return err
 278  }
 279  
 280  func isSOAFirst(in *Msg) bool {
 281  	return len(in.Answer) > 0 &&
 282  		in.Answer[0].Header().Rrtype == TypeSOA
 283  }
 284  
 285  func isSOALast(in *Msg) bool {
 286  	return len(in.Answer) > 0 &&
 287  		in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
 288  }
 289  
// errXFR is the format string for the error reported when a transfer reply
// carries a non-success rcode; the %d verb receives the message's Rcode.
const errXFR = "bad xfr rcode: %d"
 291