negentropy.go raw

   1  package negentropy
   2  
   3  import (
   4  	"fmt"
   5  )
   6  
// Negentropy handles the set reconciliation protocol.
//
// An instance holds one side of a session. Start marks it as the initiator
// and produces the first message; Reconcile processes each incoming message
// and produces the reply. IDs found to differ are accumulated in havesList
// and haveNotsList, exposed through Haves/HaveNots and drained with
// CollectHaves/CollectHaveNots.
type Negentropy struct {
	storage        Storage // local items being reconciled
	frameSizeLimit int     // size cap for outgoing messages (New raises it to at least 4096)
	idSize         int     // bytes read per ID from the peer; New sets FullIDSize

	// Result slices populated during Reconcile
	havesList    []string // IDs we have that peer needs
	haveNotsList []string // IDs peer has that we need

	// Protocol state
	isInitiator bool // set by Start; selects initiator vs. responder handling of ID lists
}
  20  
  21  // New creates a new Negentropy instance.
  22  func New(storage Storage, frameSizeLimit int) *Negentropy {
  23  	if frameSizeLimit < 4096 {
  24  		frameSizeLimit = 4096
  25  	}
  26  
  27  	return &Negentropy{
  28  		storage:        storage,
  29  		frameSizeLimit: frameSizeLimit,
  30  		idSize:         FullIDSize,
  31  	}
  32  }
  33  
  34  // Haves returns the IDs we have that the peer needs.
  35  func (n *Negentropy) Haves() []string {
  36  	return n.havesList
  37  }
  38  
  39  // HaveNots returns the IDs the peer has that we need.
  40  func (n *Negentropy) HaveNots() []string {
  41  	return n.haveNotsList
  42  }
  43  
  44  // Start creates the initial message to begin reconciliation (client side).
  45  func (n *Negentropy) Start() ([]byte, error) {
  46  	n.isInitiator = true
  47  
  48  	enc := NewEncoder(n.frameSizeLimit)
  49  
  50  	// Write protocol version
  51  	enc.WriteByte(ProtocolVersion)
  52  
  53  	// Initial range covers everything
  54  	n.splitRange(enc, 0, n.storage.Size(), MinBound(), MaxBound())
  55  
  56  	return enc.Bytes(), nil
  57  }
  58  
  59  // Reconcile processes an incoming message and generates a response.
  60  // Returns the response message and whether reconciliation is complete.
  61  func (n *Negentropy) Reconcile(msg []byte) (response []byte, complete bool, err error) {
  62  	if len(msg) == 0 {
  63  		return nil, false, ErrInvalidMessage
  64  	}
  65  
  66  	dec := NewDecoder(msg)
  67  
  68  	// Check protocol version
  69  	version, err := dec.ReadByte()
  70  	if err != nil {
  71  		return nil, false, err
  72  	}
  73  	if version != ProtocolVersion {
  74  		return nil, false, fmt.Errorf("unsupported protocol version: %x", version)
  75  	}
  76  
  77  	enc := NewEncoder(n.frameSizeLimit)
  78  	enc.WriteByte(ProtocolVersion)
  79  
  80  	// Process all ranges in the message
  81  	prevBound := MinBound()
  82  	prevIndex := 0
  83  	exceededFrameSize := false
  84  
  85  	for dec.HasMore() {
  86  		// Read upper bound of this range
  87  		upperBound, err := dec.ReadBound()
  88  		if err != nil {
  89  			return nil, false, fmt.Errorf("failed to read bound: %w", err)
  90  		}
  91  
  92  		// Read mode
  93  		mode, err := dec.ReadMode()
  94  		if err != nil {
  95  			return nil, false, fmt.Errorf("failed to read mode: %w", err)
  96  		}
  97  
  98  		// Find the range in our storage
  99  		lower := n.storage.FindLowerBound(prevIndex, n.storage.Size(), prevBound)
 100  		upper := n.storage.FindLowerBound(lower, n.storage.Size(), upperBound)
 101  
 102  		// If frame size exceeded, emit compact fingerprint for remaining ranges
 103  		// instead of detailed data. Using Fingerprint (not Skip) ensures the
 104  		// peer will re-examine these ranges in the next round.
 105  		if exceededFrameSize {
 106  			// Must still advance decoder past mode data
 107  			n.skipModeData(dec, mode)
 108  
 109  			// Emit fingerprint so peer continues reconciliation for this range
 110  			fp := n.storage.Fingerprint(lower, upper)
 111  			enc.WriteBound(upperBound)
 112  			enc.WriteMode(ModeFingerprint)
 113  			enc.WriteFingerprint(fp)
 114  
 115  			prevBound = upperBound
 116  			prevIndex = upper
 117  			continue
 118  		}
 119  
 120  		switch mode {
 121  		case ModeSkip:
 122  			// Range is in sync, skip it
 123  			n.skipRange(enc, upperBound)
 124  
 125  		case ModeFingerprint:
 126  			// Read their fingerprint
 127  			theirFP, err := dec.ReadFingerprint()
 128  			if err != nil {
 129  				return nil, false, fmt.Errorf("failed to read fingerprint: %w", err)
 130  			}
 131  
 132  			// Compare with our fingerprint
 133  			ourFP := n.storage.Fingerprint(lower, upper)
 134  
 135  			if ourFP == theirFP {
 136  				// Fingerprints match, skip this range
 137  				n.skipRange(enc, upperBound)
 138  			} else {
 139  				// Fingerprints differ, need to split or send IDs
 140  				n.splitRange(enc, lower, upper, prevBound, upperBound)
 141  			}
 142  
 143  		case ModeIdList:
 144  			// Read their ID count
 145  			numIds, err := dec.ReadVarInt()
 146  			if err != nil {
 147  				return nil, false, fmt.Errorf("failed to read id count: %w", err)
 148  			}
 149  
 150  			if n.isInitiator {
 151  				// Initiator: read their IDs, compare with ours, record diffs
 152  				theirIds := make(map[string]bool)
 153  				for i := uint64(0); i < numIds; i++ {
 154  					idBytes, err := dec.ReadBytes(n.idSize)
 155  					if err != nil {
 156  						return nil, false, fmt.Errorf("failed to read id: %w", err)
 157  					}
 158  					theirIds[encodeHex(idBytes)] = true
 159  				}
 160  
 161  				// Find differences
 162  				for _, item := range n.storage.Range(lower, upper) {
 163  					fullID := item.ID[:n.idSize*2] // hex is 2 chars per byte
 164  					if !theirIds[fullID] {
 165  						// We have it, they don't
 166  						n.havesList = append(n.havesList, item.ID)
 167  					}
 168  					delete(theirIds, fullID)
 169  				}
 170  
 171  				// Remaining IDs are ones they have that we don't
 172  				for id := range theirIds {
 173  					n.haveNotsList = append(n.haveNotsList, id)
 174  				}
 175  
 176  				// Initiator: skip this range (diffs already computed)
 177  				n.skipRange(enc, upperBound)
 178  			} else {
 179  				// Responder: read past their IDs to advance decoder position
 180  				for i := uint64(0); i < numIds; i++ {
 181  					_, err := dec.ReadBytes(n.idSize)
 182  					if err != nil {
 183  						return nil, false, fmt.Errorf("failed to read id: %w", err)
 184  					}
 185  				}
 186  
 187  				// Check if our range would exceed frame size as an ID list
 188  				numOurItems := upper - lower
 189  				if n.estimateIdListSize(numOurItems)+len(enc.Bytes()) > n.frameSizeLimit {
 190  					// Too large for ID list - use splitRange for compact fingerprint buckets
 191  					n.splitRange(enc, lower, upper, prevBound, upperBound)
 192  				} else {
 193  					// Fits in frame - send our own ID list so initiator can compute diffs
 194  					enc.WriteBound(upperBound)
 195  					enc.WriteMode(ModeIdList)
 196  					enc.WriteVarInt(uint64(numOurItems))
 197  					for _, item := range n.storage.Range(lower, upper) {
 198  						idBytes, _ := decodeHex(item.ID)
 199  						if len(idBytes) >= FullIDSize {
 200  							enc.WriteBytes(idBytes[:FullIDSize])
 201  						} else if len(idBytes) > 0 {
 202  							// Pad short IDs with zeros to FullIDSize
 203  							padded := make([]byte, FullIDSize)
 204  							copy(padded, idBytes)
 205  							enc.WriteBytes(padded)
 206  						}
 207  					}
 208  				}
 209  			}
 210  		}
 211  
 212  		// Check if we've exceeded the frame size limit after processing this range
 213  		if len(enc.Bytes()) > n.frameSizeLimit {
 214  			exceededFrameSize = true
 215  		}
 216  
 217  		prevBound = upperBound
 218  		prevIndex = upper
 219  	}
 220  
 221  	response = enc.Bytes()
 222  
 223  	// Check if reconciliation is complete
 224  	// Complete when response only contains version + all skips
 225  	complete = n.isResponseComplete(response)
 226  
 227  	return response, complete, nil
 228  }
 229  
 230  // skipModeData advances the decoder past the data for a given mode without
 231  // processing it. Used when frame size is exceeded and we need to defer ranges.
 232  func (n *Negentropy) skipModeData(dec *Decoder, mode Mode) {
 233  	switch mode {
 234  	case ModeSkip:
 235  		// No additional data
 236  	case ModeFingerprint:
 237  		// Read and discard fingerprint (16 bytes)
 238  		dec.ReadBytes(DefaultIDSize)
 239  	case ModeIdList:
 240  		// Read count, then skip that many IDs
 241  		numIds, err := dec.ReadVarInt()
 242  		if err != nil {
 243  			return
 244  		}
 245  		for i := uint64(0); i < numIds; i++ {
 246  			dec.ReadBytes(n.idSize)
 247  		}
 248  	}
 249  }
 250  
 251  // skipRange writes a skip mode for the given bound.
 252  func (n *Negentropy) skipRange(enc *Encoder, upperBound Bound) {
 253  	enc.WriteBound(upperBound)
 254  	enc.WriteMode(ModeSkip)
 255  }
 256  
 257  // splitRange splits a range into buckets and writes fingerprints or ID lists.
 258  func (n *Negentropy) splitRange(enc *Encoder, lower, upper int, lowerBound, upperBound Bound) {
 259  	numItems := upper - lower
 260  
 261  	if numItems == 0 {
 262  		// Empty range, send as ID list with 0 items
 263  		enc.WriteBound(upperBound)
 264  		enc.WriteMode(ModeIdList)
 265  		enc.WriteVarInt(0)
 266  		return
 267  	}
 268  
 269  	// For small ranges, send full ID list if it fits in the remaining frame space
 270  	if numItems <= 2 || n.estimateIdListSize(numItems) < n.frameSizeLimit/10 {
 271  		// Also check cumulative frame size before writing ID list
 272  		if n.estimateIdListSize(numItems)+len(enc.Bytes()) <= n.frameSizeLimit {
 273  			enc.WriteBound(upperBound)
 274  			enc.WriteMode(ModeIdList)
 275  			enc.WriteVarInt(uint64(numItems))
 276  			for _, item := range n.storage.Range(lower, upper) {
 277  				idBytes, _ := decodeHex(item.ID)
 278  				if len(idBytes) >= FullIDSize {
 279  					enc.WriteBytes(idBytes[:FullIDSize])
 280  				} else if len(idBytes) > 0 {
 281  					// Pad short IDs with zeros to FullIDSize
 282  					padded := make([]byte, FullIDSize)
 283  					copy(padded, idBytes)
 284  					enc.WriteBytes(padded)
 285  				}
 286  			}
 287  			return
 288  		}
 289  		// ID list would exceed frame - fall through to fingerprint buckets
 290  	}
 291  
 292  	// For larger ranges, split into buckets with fingerprints
 293  	numBuckets := n.calculateBuckets(numItems)
 294  	itemsPerBucket := numItems / numBuckets
 295  
 296  	for i := 0; i < numBuckets; i++ {
 297  		bucketStart := lower + i*itemsPerBucket
 298  		bucketEnd := lower + (i+1)*itemsPerBucket
 299  		if i == numBuckets-1 {
 300  			bucketEnd = upper // Last bucket gets remainder
 301  		}
 302  
 303  		var bucketBound Bound
 304  		if i == numBuckets-1 {
 305  			// Last bucket must use the original upperBound to maintain
 306  			// range alignment with the peer. Using GetBound(bucketEnd)
 307  			// can produce a different bound when the peer's boundary
 308  			// event doesn't exist in our storage, causing range
 309  			// misalignment and false fingerprint mismatches.
 310  			bucketBound = upperBound
 311  		} else {
 312  			bucketBound = n.storage.GetBound(bucketEnd)
 313  		}
 314  
 315  		enc.WriteBound(bucketBound)
 316  		enc.WriteMode(ModeFingerprint)
 317  		fp := n.storage.Fingerprint(bucketStart, bucketEnd)
 318  		enc.WriteFingerprint(fp)
 319  	}
 320  }
 321  
 322  // estimateIdListSize estimates the encoded size of an ID list.
 323  func (n *Negentropy) estimateIdListSize(numItems int) int {
 324  	// Bound + mode + count varint + (idSize * numItems)
 325  	return 20 + VarIntSize(uint64(numItems)) + n.idSize*numItems
 326  }
 327  
 328  // calculateBuckets determines how many buckets to split a range into.
 329  func (n *Negentropy) calculateBuckets(numItems int) int {
 330  	// Use square root heuristic, clamped to reasonable bounds
 331  	buckets := 1
 332  	for buckets*buckets < numItems {
 333  		buckets++
 334  	}
 335  	if buckets < 2 {
 336  		buckets = 2
 337  	}
 338  	if buckets > 16 {
 339  		buckets = 16
 340  	}
 341  	return buckets
 342  }
 343  
 344  // isResponseComplete checks if the response indicates reconciliation is complete.
 345  func (n *Negentropy) isResponseComplete(response []byte) bool {
 346  	if len(response) <= 1 {
 347  		return true
 348  	}
 349  
 350  	dec := NewDecoder(response)
 351  
 352  	// Skip version
 353  	_, err := dec.ReadByte()
 354  	if err != nil {
 355  		return false
 356  	}
 357  
 358  	// Check if all ranges are skips
 359  	for dec.HasMore() {
 360  		_, err := dec.ReadBound()
 361  		if err != nil {
 362  			return false
 363  		}
 364  
 365  		mode, err := dec.ReadMode()
 366  		if err != nil {
 367  			return false
 368  		}
 369  
 370  		if mode != ModeSkip {
 371  			return false
 372  		}
 373  	}
 374  
 375  	return true
 376  }
 377  
 378  // Close is a no-op (retained for API compatibility).
 379  func (n *Negentropy) Close() {
 380  }
 381  
 382  // CollectHaves returns all IDs we have that the peer needs, and resets the list.
 383  func (n *Negentropy) CollectHaves() []string {
 384  	result := n.havesList
 385  	n.havesList = nil
 386  	return result
 387  }
 388  
 389  // CollectHaveNots returns all IDs the peer has that we need, and resets the list.
 390  func (n *Negentropy) CollectHaveNots() []string {
 391  	result := n.haveNotsList
 392  	n.haveNotsList = nil
 393  	return result
 394  }
 395