merkleblock.go raw

   1  package bloom
   2  
   3  import (
   4  	"github.com/p9c/p9/pkg/block"
   5  	"github.com/p9c/p9/pkg/blockchain"
   6  	"github.com/p9c/p9/pkg/chainhash"
   7  	"github.com/p9c/p9/pkg/wire"
   8  )
   9  
  10  // merkleBlock is used to house intermediate information needed to generate a wire.MsgMerkleBlock according to a filter.
  11  type merkleBlock struct {
  12  	numTx       uint32
  13  	allHashes   []*chainhash.Hash
  14  	finalHashes []*chainhash.Hash
  15  	matchedBits []byte
  16  	bits        []byte
  17  }
  18  
  19  // calcTreeWidth calculates and returns the the number of nodes (width) or a merkle tree at the given depth-first
  20  // height.
  21  func (m *merkleBlock) calcTreeWidth(height uint32) uint32 {
  22  	return (m.numTx + (1 << height) - 1) >> height
  23  }
  24  
  25  // calcHash returns the hash for a sub-tree given a depth-first height and node position.
  26  func (m *merkleBlock) calcHash(height, pos uint32) *chainhash.Hash {
  27  	if height == 0 {
  28  		return m.allHashes[pos]
  29  	}
  30  	var right *chainhash.Hash
  31  	left := m.calcHash(height-1, pos*2)
  32  	if pos*2+1 < m.calcTreeWidth(height-1) {
  33  		right = m.calcHash(height-1, pos*2+1)
  34  	} else {
  35  		right = left
  36  	}
  37  	return blockchain.HashMerkleBranches(left, right)
  38  }
  39  
  40  // traverseAndBuild builds a partial merkle tree using a recursive depth-first approach. As it calculates the hashes, it
  41  // also saves whether or not each node is a parent node and a list of final hashes to be included in the merkle block.
  42  func (m *merkleBlock) traverseAndBuild(height, pos uint32) {
  43  	// Determine whether this node is a parent of a matched node.
  44  	var isParent byte
  45  	for i := pos << height; i < (pos+1)<<height && i < m.numTx; i++ {
  46  		isParent |= m.matchedBits[i]
  47  	}
  48  	m.bits = append(m.bits, isParent)
  49  	// When the node is a leaf node or not a parent of a matched node, append the hash to the list that will be part of
  50  	// the final merkle block.
  51  	if height == 0 || isParent == 0x00 {
  52  		m.finalHashes = append(m.finalHashes, m.calcHash(height, pos))
  53  		return
  54  	}
  55  	// At this point, the node is an internal node and it is the parent of of an included leaf node. Descend into the
  56  	// left child and process its sub-tree.
  57  	m.traverseAndBuild(height-1, pos*2)
  58  	// Descend into the right child and process its sub-tree if there is one.
  59  	if pos*2+1 < m.calcTreeWidth(height-1) {
  60  		m.traverseAndBuild(height-1, pos*2+1)
  61  	}
  62  }
  63  
  64  // NewMerkleBlock returns a new *wire.MsgMerkleBlock and an array of the matched transaction index numbers based on the
  65  // passed block and filter.
  66  func NewMerkleBlock(block *block.Block, filter *Filter) (*wire.MsgMerkleBlock, []uint32) {
  67  	numTx := uint32(len(block.Transactions()))
  68  	mBlock := merkleBlock{
  69  		numTx:       numTx,
  70  		allHashes:   make([]*chainhash.Hash, 0, numTx),
  71  		matchedBits: make([]byte, 0, numTx),
  72  	}
  73  	// Find and keep track of any transactions that match the filter.
  74  	var matchedIndices []uint32
  75  	for txIndex, tx := range block.Transactions() {
  76  		if filter.MatchTxAndUpdate(tx) {
  77  			mBlock.matchedBits = append(mBlock.matchedBits, 0x01)
  78  			matchedIndices = append(matchedIndices, uint32(txIndex))
  79  		} else {
  80  			mBlock.matchedBits = append(mBlock.matchedBits, 0x00)
  81  		}
  82  		mBlock.allHashes = append(mBlock.allHashes, tx.Hash())
  83  	}
  84  	// Calculate the number of merkle branches (height) in the tree.
  85  	height := uint32(0)
  86  	for mBlock.calcTreeWidth(height) > 1 {
  87  		height++
  88  	}
  89  	// Build the depth-first partial merkle tree.
  90  	mBlock.traverseAndBuild(height, 0)
  91  	// Create and return the merkle block.
  92  	msgMerkleBlock := wire.MsgMerkleBlock{
  93  		Header:       block.WireBlock().Header,
  94  		Transactions: mBlock.numTx,
  95  		Hashes:       make([]*chainhash.Hash, 0, len(mBlock.finalHashes)),
  96  		Flags:        make([]byte, (len(mBlock.bits)+7)/8),
  97  	}
  98  	for _, hash := range mBlock.finalHashes {
  99  		e := msgMerkleBlock.AddTxHash(hash)
 100  		if e != nil {
 101  			E.Ln(e)
 102  		}
 103  	}
 104  	for i := uint32(0); i < uint32(len(mBlock.bits)); i++ {
 105  		msgMerkleBlock.Flags[i/8] |= mBlock.bits[i] << (i % 8)
 106  	}
 107  	return &msgMerkleBlock, matchedIndices
 108  }
 109