message_test.go raw

   1  package wire
   2  
   3  import (
   4  	"bytes"
   5  	"encoding/binary"
   6  	"io"
   7  	"net"
   8  	"reflect"
   9  	"testing"
  10  	"time"
  11  	
  12  	"github.com/davecgh/go-spew/spew"
  13  	
  14  	"github.com/p9c/p9/pkg/chainhash"
  15  )
  16  
  17  // makeHeader is a convenience function to make a message header in the form of a byte slice.  It is used to force errors when reading messages.
  18  func makeHeader(
  19  	btcnet BitcoinNet, command string,
  20  	payloadLen uint32, checksum uint32,
  21  ) []byte {
  22  	// The length of a bitcoin message header is 24 bytes.
  23  	// 4 byte magic number of the bitcoin network + 12 byte command + 4 byte payload length + 4 byte checksum.
  24  	buf := make([]byte, 24)
  25  	binary.LittleEndian.PutUint32(buf, uint32(btcnet))
  26  	copy(buf[4:], command)
  27  	binary.LittleEndian.PutUint32(buf[16:], payloadLen)
  28  	binary.LittleEndian.PutUint32(buf[20:], checksum)
  29  	return buf
  30  }
  31  
  32  // TestMessage tests the Read/WriteMessage and Read/WriteMessageN API.
  33  func TestMessage(t *testing.T) {
  34  	pver := ProtocolVersion
  35  	// Create the various types of messages to test. MsgVersion.
  36  	addrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 11047}
  37  	you := NewNetAddress(addrYou, SFNodeNetwork)
  38  	you.Timestamp = time.Time{} // Version message has zero value timestamp.
  39  	addrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 11047}
  40  	me := NewNetAddress(addrMe, SFNodeNetwork)
  41  	me.Timestamp = time.Time{} // Version message has zero value timestamp.
  42  	msgVersion := NewMsgVersion(me, you, 123123, 0)
  43  	msgVerack := NewMsgVerAck()
  44  	msgGetAddr := NewMsgGetAddr()
  45  	msgAddr := NewMsgAddr()
  46  	msgGetBlocks := NewMsgGetBlocks(&chainhash.Hash{})
  47  	msgBlock := &blockOne
  48  	msgInv := NewMsgInv()
  49  	msgGetData := NewMsgGetData()
  50  	msgNotFound := NewMsgNotFound()
  51  	msgTx := NewMsgTx(1)
  52  	msgPing := NewMsgPing(123123)
  53  	msgPong := NewMsgPong(123123)
  54  	msgGetHeaders := NewMsgGetHeaders()
  55  	msgHeaders := NewMsgHeaders()
  56  	msgAlert := NewMsgAlert([]byte("payload"), []byte("signature"))
  57  	msgMemPool := NewMsgMemPool()
  58  	msgFilterAdd := NewMsgFilterAdd([]byte{0x01})
  59  	msgFilterClear := NewMsgFilterClear()
  60  	msgFilterLoad := NewMsgFilterLoad([]byte{0x01}, 10, 0, BloomUpdateNone)
  61  	bh := NewBlockHeader(1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0)
  62  	msgMerkleBlock := NewMsgMerkleBlock(bh)
  63  	msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block")
  64  	msgGetCFilters := NewMsgGetCFilters(GCSFilterRegular, 0, &chainhash.Hash{})
  65  	msgGetCFHeaders := NewMsgGetCFHeaders(GCSFilterRegular, 0, &chainhash.Hash{})
  66  	msgGetCFCheckpt := NewMsgGetCFCheckpt(GCSFilterRegular, &chainhash.Hash{})
  67  	msgCFilter := NewMsgCFilter(
  68  		GCSFilterRegular, &chainhash.Hash{},
  69  		[]byte("payload"),
  70  	)
  71  	msgCFHeaders := NewMsgCFHeaders()
  72  	msgCFCheckpt := NewMsgCFCheckpt(GCSFilterRegular, &chainhash.Hash{}, 0)
  73  	tests := []struct {
  74  		in     Message    // value to encode
  75  		out    Message    // Expected decoded value
  76  		pver   uint32     // Protocol version for wire encoding
  77  		btcnet BitcoinNet // Network to use for wire encoding
  78  		bytes  int        // Expected num bytes read/written
  79  	}{
  80  		{msgVersion, msgVersion, pver, MainNet, 125},
  81  		{msgVerack, msgVerack, pver, MainNet, 24},
  82  		{msgGetAddr, msgGetAddr, pver, MainNet, 24},
  83  		{msgAddr, msgAddr, pver, MainNet, 25},
  84  		{msgGetBlocks, msgGetBlocks, pver, MainNet, 61},
  85  		{msgBlock, msgBlock, pver, MainNet, 239},
  86  		{msgInv, msgInv, pver, MainNet, 25},
  87  		{msgGetData, msgGetData, pver, MainNet, 25},
  88  		{msgNotFound, msgNotFound, pver, MainNet, 25},
  89  		{msgTx, msgTx, pver, MainNet, 34},
  90  		{msgPing, msgPing, pver, MainNet, 32},
  91  		{msgPong, msgPong, pver, MainNet, 32},
  92  		{msgGetHeaders, msgGetHeaders, pver, MainNet, 61},
  93  		{msgHeaders, msgHeaders, pver, MainNet, 25},
  94  		{msgAlert, msgAlert, pver, MainNet, 42},
  95  		{msgMemPool, msgMemPool, pver, MainNet, 24},
  96  		{msgFilterAdd, msgFilterAdd, pver, MainNet, 26},
  97  		{msgFilterClear, msgFilterClear, pver, MainNet, 24},
  98  		{msgFilterLoad, msgFilterLoad, pver, MainNet, 35},
  99  		{msgMerkleBlock, msgMerkleBlock, pver, MainNet, 110},
 100  		{msgReject, msgReject, pver, MainNet, 79},
 101  		{msgGetCFilters, msgGetCFilters, pver, MainNet, 61},
 102  		{msgGetCFHeaders, msgGetCFHeaders, pver, MainNet, 61},
 103  		{msgGetCFCheckpt, msgGetCFCheckpt, pver, MainNet, 57},
 104  		{msgCFilter, msgCFilter, pver, MainNet, 65},
 105  		{msgCFHeaders, msgCFHeaders, pver, MainNet, 90},
 106  		{msgCFCheckpt, msgCFCheckpt, pver, MainNet, 58},
 107  	}
 108  	t.Logf("Running %d tests", len(tests))
 109  	var msg Message
 110  	for i, test := range tests {
 111  		// Encode to wire format.
 112  		var buf bytes.Buffer
 113  		nw, e := WriteMessageN(&buf, test.in, test.pver, test.btcnet)
 114  		if e != nil {
 115  			t.Errorf("WriteMessage #%d error %v", i, e)
 116  			continue
 117  		}
 118  		// Ensure the number of bytes written match the expected value.
 119  		if nw != test.bytes {
 120  			t.Errorf(
 121  				"WriteMessage #%d unexpected num bytes "+
 122  					"written - got %d, want %d", i, nw, test.bytes,
 123  			)
 124  		}
 125  		// Decode from wire format.
 126  		rbuf := bytes.NewReader(buf.Bytes())
 127  		var nr int
 128  		nr, msg, _, e = ReadMessageN(rbuf, test.pver, test.btcnet)
 129  		if e != nil {
 130  			t.Errorf(
 131  				"ReadMessage #%d error %v, msg %v", i, e,
 132  				spew.Sdump(msg),
 133  			)
 134  			continue
 135  		}
 136  		if !reflect.DeepEqual(msg, test.out) {
 137  			t.Errorf(
 138  				"ReadMessage #%d\n got: %v want: %v", i,
 139  				spew.Sdump(msg), spew.Sdump(test.out),
 140  			)
 141  			continue
 142  		}
 143  		// Ensure the number of bytes read match the expected value.
 144  		if nr != test.bytes {
 145  			t.Errorf(
 146  				"ReadMessage #%d unexpected num bytes read - "+
 147  					"got %d, want %d", i, nr, test.bytes,
 148  			)
 149  		}
 150  	}
 151  	// Do the same thing for Read/WriteMessage, but ignore the bytes since they don't return them.
 152  	t.Logf("Running %d tests", len(tests))
 153  	for i, test := range tests {
 154  		// Encode to wire format.
 155  		var buf bytes.Buffer
 156  		e := WriteMessage(&buf, test.in, test.pver, test.btcnet)
 157  		if e != nil {
 158  			t.Errorf("WriteMessage #%d error %v", i, e)
 159  			continue
 160  		}
 161  		// Decode from wire format.
 162  		rbuf := bytes.NewReader(buf.Bytes())
 163  		msg, _, e = ReadMessage(rbuf, test.pver, test.btcnet)
 164  		if e != nil {
 165  			t.Errorf(
 166  				"ReadMessage #%d error %v, msg %v", i, e,
 167  				spew.Sdump(msg),
 168  			)
 169  			continue
 170  		}
 171  		if !reflect.DeepEqual(msg, test.out) {
 172  			t.Errorf(
 173  				"ReadMessage #%d\n got: %v want: %v", i,
 174  				spew.Sdump(msg), spew.Sdump(test.out),
 175  			)
 176  			continue
 177  		}
 178  	}
 179  }
 180  
 181  // TestReadMessageWireErrors performs negative tests against wire decoding into concrete messages to confirm error paths
 182  // work correctly.
 183  func TestReadMessageWireErrors(t *testing.T) {
 184  	pver := ProtocolVersion
 185  	btcnet := MainNet
 186  	// Ensure message errors are as expected with no function specified.
 187  	wantErr := "something bad happened"
 188  	testErr := MessageError{Description: wantErr}
 189  	if testErr.Error() != wantErr {
 190  		t.Errorf(
 191  			"MessageError: wrong error - got %v, want %v",
 192  			testErr.Error(), wantErr,
 193  		)
 194  	}
 195  	// Ensure message errors are as expected with a function specified.
 196  	wantFunc := "foo"
 197  	testErr = MessageError{Func: wantFunc, Description: wantErr}
 198  	if testErr.Error() != wantFunc+": "+wantErr {
 199  		t.Errorf(
 200  			"MessageError: wrong error - got %v, want %v",
 201  			testErr.Error(), wantErr,
 202  		)
 203  	}
 204  	// Wire encoded bytes for main and testnet3 networks magic identifiers.
 205  	testNet3Bytes := makeHeader(TestNet3, "", 0, 0)
 206  	// Wire encoded bytes for a message that exceeds max overall message length.
 207  	mpl := uint32(MaxMessagePayload)
 208  	exceedMaxPayloadBytes := makeHeader(btcnet, "getaddr", mpl+1, 0)
 209  	// Wire encoded bytes for a command which is invalid utf-8.
 210  	badCommandBytes := makeHeader(btcnet, "bogus", 0, 0)
 211  	badCommandBytes[4] = 0x81
 212  	// Wire encoded bytes for a command which is valid, but not supported.
 213  	unsupportedCommandBytes := makeHeader(btcnet, "bogus", 0, 0)
 214  	// Wire encoded bytes for a message which exceeds the max payload for a specific message type.
 215  	exceedTypePayloadBytes := makeHeader(btcnet, "getaddr", 1, 0)
 216  	// Wire encoded bytes for a message which does not deliver the full payload according to the header length.
 217  	shortPayloadBytes := makeHeader(btcnet, "version", 115, 0)
 218  	// Wire encoded bytes for a message with a bad checksum.
 219  	badChecksumBytes := makeHeader(btcnet, "version", 2, 0xbeef)
 220  	badChecksumBytes = append(badChecksumBytes, []byte{0x0, 0x0}...)
 221  	// Wire encoded bytes for a message which has a valid header, but is the wrong format. An addr starts with a varint
 222  	// of the number of contained in the message. Claim there is two, but don't provide them. At the same time, forge
 223  	// the header fields so the message is otherwise accurate.
 224  	badMessageBytes := makeHeader(btcnet, "addr", 1, 0xeaadc31c)
 225  	badMessageBytes = append(badMessageBytes, 0x2)
 226  	// Wire encoded bytes for a message which the header claims has 15k bytes of data to discard.
 227  	discardBytes := makeHeader(btcnet, "bogus", 15*1024, 0)
 228  	tests := []struct {
 229  		buf     []byte     // Wire encoding
 230  		pver    uint32     // Protocol version for wire encoding
 231  		btcnet  BitcoinNet // Bitcoin network for wire encoding
 232  		max     int        // Max size of fixed buffer to induce errors
 233  		readErr error      // Expected read error
 234  		bytes   int        // Expected num bytes read
 235  	}{
 236  		// Latest protocol version with intentional read errors.
 237  		// Short header.
 238  		{
 239  			[]byte{},
 240  			pver,
 241  			btcnet,
 242  			0,
 243  			io.EOF,
 244  			0,
 245  		},
 246  		// Wrong network.  Want MainNet, but giving TestNet3.
 247  		{
 248  			testNet3Bytes,
 249  			pver,
 250  			btcnet,
 251  			len(testNet3Bytes),
 252  			&MessageError{},
 253  			24,
 254  		},
 255  		// Exceed max overall message payload length.
 256  		{
 257  			exceedMaxPayloadBytes,
 258  			pver,
 259  			btcnet,
 260  			len(exceedMaxPayloadBytes),
 261  			&MessageError{},
 262  			24,
 263  		},
 264  		// Invalid UTF-8 command.
 265  		{
 266  			badCommandBytes,
 267  			pver,
 268  			btcnet,
 269  			len(badCommandBytes),
 270  			&MessageError{},
 271  			24,
 272  		},
 273  		// Valid, but unsupported command.
 274  		{
 275  			unsupportedCommandBytes,
 276  			pver,
 277  			btcnet,
 278  			len(unsupportedCommandBytes),
 279  			&MessageError{},
 280  			24,
 281  		},
 282  		// Exceed max allowed payload for a message of a specific type.
 283  		{
 284  			exceedTypePayloadBytes,
 285  			pver,
 286  			btcnet,
 287  			len(exceedTypePayloadBytes),
 288  			&MessageError{},
 289  			24,
 290  		},
 291  		// Message with a payload shorter than the header indicates.
 292  		{
 293  			shortPayloadBytes,
 294  			pver,
 295  			btcnet,
 296  			len(shortPayloadBytes),
 297  			io.EOF,
 298  			24,
 299  		},
 300  		// Message with a bad checksum.
 301  		{
 302  			badChecksumBytes,
 303  			pver,
 304  			btcnet,
 305  			len(badChecksumBytes),
 306  			&MessageError{},
 307  			26,
 308  		},
 309  		// Message with a valid header, but wrong format.
 310  		{
 311  			badMessageBytes,
 312  			pver,
 313  			btcnet,
 314  			len(badMessageBytes),
 315  			io.EOF,
 316  			25,
 317  		},
 318  		// 15k bytes of data to discard.
 319  		{
 320  			discardBytes,
 321  			pver,
 322  			btcnet,
 323  			len(discardBytes),
 324  			&MessageError{},
 325  			24,
 326  		},
 327  	}
 328  	t.Logf("Running %d tests", len(tests))
 329  	for i, test := range tests {
 330  		// Decode from wire format.
 331  		r := newFixedReader(test.max, test.buf)
 332  		var nr int
 333  		var e error
 334  		nr, _, _, e = ReadMessageN(r, test.pver, test.btcnet)
 335  		if reflect.TypeOf(e) != reflect.TypeOf(test.readErr) {
 336  			t.Errorf(
 337  				"ReadMessage #%d wrong error got: %v <%T>, "+
 338  					"want: %T", i, e, e, test.readErr,
 339  			)
 340  			continue
 341  		}
 342  		// Ensure the number of bytes written match the expected value.
 343  		if nr != test.bytes {
 344  			t.Errorf(
 345  				"ReadMessage #%d unexpected num bytes read - "+
 346  					"got %d, want %d", i, nr, test.bytes,
 347  			)
 348  		}
 349  		// For errors which are not of type MessageError, check them for equality.
 350  		if _, ok := e.(*MessageError); !ok {
 351  			if e != test.readErr {
 352  				t.Errorf(
 353  					"ReadMessage #%d wrong error got: %v <%T>, "+
 354  						"want: %v <%T>", i, e, e,
 355  					test.readErr, test.readErr,
 356  				)
 357  				continue
 358  			}
 359  		}
 360  	}
 361  }
 362  
 363  // TestWriteMessageWireErrors performs negative tests against wire encoding from concrete messages to confirm error
 364  // paths work correctly.
 365  func TestWriteMessageWireErrors(t *testing.T) {
 366  	pver := ProtocolVersion
 367  	btcnet := MainNet
 368  	wireErr := &MessageError{}
 369  	// Fake message with a command that is too long.
 370  	badCommandMsg := &fakeMessage{command: "somethingtoolong"}
 371  	// Fake message with a problem during encoding
 372  	encodeErrMsg := &fakeMessage{forceEncodeErr: true}
 373  	// Fake message that has payload which exceeds max overall message size.
 374  	exceedOverallPayload := make([]byte, MaxMessagePayload+1)
 375  	exceedOverallPayloadErrMsg := &fakeMessage{payload: exceedOverallPayload}
 376  	// Fake message that has payload which exceeds max allowed per message.
 377  	exceedPayload := make([]byte, 1)
 378  	exceedPayloadErrMsg := &fakeMessage{payload: exceedPayload, forceLenErr: true}
 379  	// Fake message that is used to force errors in the header and payload writes.
 380  	bogusPayload := []byte{0x01, 0x02, 0x03, 0x04}
 381  	bogusMsg := &fakeMessage{command: "bogus", payload: bogusPayload}
 382  	tests := []struct {
 383  		msg    Message    // Message to encode
 384  		pver   uint32     // Protocol version for wire encoding
 385  		btcnet BitcoinNet // Bitcoin network for wire encoding
 386  		max    int        // Max size of fixed buffer to induce errors
 387  		err    error      // Expected error
 388  		bytes  int        // Expected num bytes written
 389  	}{
 390  		// Command too long.
 391  		{badCommandMsg, pver, btcnet, 0, wireErr, 0},
 392  		// Force error in payload encode.
 393  		{encodeErrMsg, pver, btcnet, 0, wireErr, 0},
 394  		// Force error due to exceeding max overall message payload size.
 395  		{exceedOverallPayloadErrMsg, pver, btcnet, 0, wireErr, 0},
 396  		// Force error due to exceeding max payload for message type.
 397  		{exceedPayloadErrMsg, pver, btcnet, 0, wireErr, 0},
 398  		// Force error in header write.
 399  		{bogusMsg, pver, btcnet, 0, io.ErrShortWrite, 0},
 400  		// Force error in payload write.
 401  		{bogusMsg, pver, btcnet, 24, io.ErrShortWrite, 24},
 402  	}
 403  	t.Logf("Running %d tests", len(tests))
 404  	for i, test := range tests {
 405  		// Encode wire format.
 406  		w := newFixedWriter(test.max)
 407  		nw, e := WriteMessageN(w, test.msg, test.pver, test.btcnet)
 408  		if reflect.TypeOf(e) != reflect.TypeOf(test.err) {
 409  			t.Errorf(
 410  				"WriteMessage #%d wrong error got: %v <%T>, "+
 411  					"want: %T", i, e, e, test.err,
 412  			)
 413  			continue
 414  		}
 415  		// Ensure the number of bytes written match the expected value.
 416  		if nw != test.bytes {
 417  			t.Errorf(
 418  				"WriteMessage #%d unexpected num bytes "+
 419  					"written - got %d, want %d", i, nw, test.bytes,
 420  			)
 421  		}
 422  		// For errors which are not of type MessageError, check them for equality.
 423  		if _, ok := e.(*MessageError); !ok {
 424  			if e != test.err {
 425  				t.Errorf(
 426  					"ReadMessage #%d wrong error got: %v <%T>, "+
 427  						"want: %v <%T>", i, e, e,
 428  					test.err, test.err,
 429  				)
 430  				continue
 431  			}
 432  		}
 433  	}
 434  }
 435