subnet_pool.go raw

   1  package wireguard
   2  
   3  import (
   4  	"crypto/sha256"
   5  	"encoding/binary"
   6  	"fmt"
   7  	"net/netip"
   8  	"sync"
   9  
  10  	"lukechampine.com/frand"
  11  )
  12  
  13  // Subnet represents a /31 point-to-point subnet.
  14  type Subnet struct {
  15  	ServerIP netip.Addr // Even address (server side)
  16  	ClientIP netip.Addr // Odd address (client side)
  17  }
  18  
  19  // SubnetPool manages deterministic /31 subnet generation from a seed.
  20  // Given the same seed and sequence number, the same subnet is always generated.
  21  type SubnetPool struct {
  22  	seed       [32]byte       // Random seed for deterministic generation
  23  	basePrefix netip.Prefix   // e.g., 10.0.0.0/8
  24  	maxSeq     uint32         // Current highest sequence number
  25  	assigned   map[string]uint32 // Client pubkey hex -> sequence number
  26  	mu         sync.RWMutex
  27  }
  28  
  29  // NewSubnetPool creates a subnet pool with a new random seed.
  30  func NewSubnetPool(baseNetwork string) (*SubnetPool, error) {
  31  	prefix, err := netip.ParsePrefix(baseNetwork)
  32  	if err != nil {
  33  		return nil, fmt.Errorf("invalid base network: %w", err)
  34  	}
  35  
  36  	var seed [32]byte
  37  	frand.Read(seed[:])
  38  
  39  	return &SubnetPool{
  40  		seed:       seed,
  41  		basePrefix: prefix,
  42  		maxSeq:     0,
  43  		assigned:   make(map[string]uint32),
  44  	}, nil
  45  }
  46  
  47  // NewSubnetPoolWithSeed creates a subnet pool with an existing seed.
  48  func NewSubnetPoolWithSeed(baseNetwork string, seed []byte) (*SubnetPool, error) {
  49  	prefix, err := netip.ParsePrefix(baseNetwork)
  50  	if err != nil {
  51  		return nil, fmt.Errorf("invalid base network: %w", err)
  52  	}
  53  
  54  	if len(seed) != 32 {
  55  		return nil, fmt.Errorf("seed must be 32 bytes, got %d", len(seed))
  56  	}
  57  
  58  	pool := &SubnetPool{
  59  		basePrefix: prefix,
  60  		maxSeq:     0,
  61  		assigned:   make(map[string]uint32),
  62  	}
  63  	copy(pool.seed[:], seed)
  64  
  65  	return pool, nil
  66  }
  67  
  68  // Seed returns the pool's seed for persistence.
  69  func (p *SubnetPool) Seed() []byte {
  70  	return p.seed[:]
  71  }
  72  
  73  // deriveSubnet deterministically generates a /31 subnet from seed + sequence.
  74  func (p *SubnetPool) deriveSubnet(seq uint32) Subnet {
  75  	// Hash seed + sequence to get deterministic randomness
  76  	h := sha256.New()
  77  	h.Write(p.seed[:])
  78  	binary.Write(h, binary.BigEndian, seq)
  79  	hash := h.Sum(nil)
  80  
  81  	// Use first 4 bytes as offset within the prefix
  82  	offset := binary.BigEndian.Uint32(hash[:4])
  83  
  84  	// Calculate available address space
  85  	bits := p.basePrefix.Bits()
  86  	availableBits := uint32(32 - bits)
  87  	maxOffset := uint32(1) << availableBits
  88  
  89  	// Make offset even (for /31 alignment) and within range
  90  	offset = (offset % (maxOffset / 2)) * 2
  91  
  92  	// Calculate server IP (even) and client IP (odd)
  93  	baseAddr := p.basePrefix.Addr()
  94  	baseBytes := baseAddr.As4()
  95  	baseVal := uint32(baseBytes[0])<<24 | uint32(baseBytes[1])<<16 |
  96  		uint32(baseBytes[2])<<8 | uint32(baseBytes[3])
  97  
  98  	serverVal := baseVal + offset
  99  	clientVal := serverVal + 1
 100  
 101  	serverBytes := [4]byte{
 102  		byte(serverVal >> 24), byte(serverVal >> 16),
 103  		byte(serverVal >> 8), byte(serverVal),
 104  	}
 105  	clientBytes := [4]byte{
 106  		byte(clientVal >> 24), byte(clientVal >> 16),
 107  		byte(clientVal >> 8), byte(clientVal),
 108  	}
 109  
 110  	return Subnet{
 111  		ServerIP: netip.AddrFrom4(serverBytes),
 112  		ClientIP: netip.AddrFrom4(clientBytes),
 113  	}
 114  }
 115  
 116  // ServerIPs returns server-side IPs for sequences 0 to maxSeq (for netstack).
 117  func (p *SubnetPool) ServerIPs() []netip.Addr {
 118  	p.mu.RLock()
 119  	defer p.mu.RUnlock()
 120  
 121  	if p.maxSeq == 0 {
 122  		return nil
 123  	}
 124  
 125  	ips := make([]netip.Addr, p.maxSeq)
 126  	for seq := uint32(0); seq < p.maxSeq; seq++ {
 127  		subnet := p.deriveSubnet(seq)
 128  		ips[seq] = subnet.ServerIP
 129  	}
 130  	return ips
 131  }
 132  
 133  // GetSubnet returns the subnet for a client, or nil if not assigned.
 134  func (p *SubnetPool) GetSubnet(clientPubkeyHex string) *Subnet {
 135  	p.mu.RLock()
 136  	defer p.mu.RUnlock()
 137  
 138  	if seq, ok := p.assigned[clientPubkeyHex]; ok {
 139  		subnet := p.deriveSubnet(seq)
 140  		return &subnet
 141  	}
 142  	return nil
 143  }
 144  
 145  // GetSequence returns the sequence number for a client, or -1 if not assigned.
 146  func (p *SubnetPool) GetSequence(clientPubkeyHex string) int {
 147  	p.mu.RLock()
 148  	defer p.mu.RUnlock()
 149  
 150  	if seq, ok := p.assigned[clientPubkeyHex]; ok {
 151  		return int(seq)
 152  	}
 153  	return -1
 154  }
 155  
 156  // RestoreAllocation restores a previously saved allocation.
 157  func (p *SubnetPool) RestoreAllocation(clientPubkeyHex string, seq uint32) {
 158  	p.mu.Lock()
 159  	defer p.mu.Unlock()
 160  
 161  	p.assigned[clientPubkeyHex] = seq
 162  	if seq >= p.maxSeq {
 163  		p.maxSeq = seq + 1
 164  	}
 165  }
 166  
 167  // MaxSequence returns the current max sequence number.
 168  func (p *SubnetPool) MaxSequence() uint32 {
 169  	p.mu.RLock()
 170  	defer p.mu.RUnlock()
 171  	return p.maxSeq
 172  }
 173  
 174  // AllocatedCount returns the number of allocated subnets.
 175  func (p *SubnetPool) AllocatedCount() int {
 176  	p.mu.RLock()
 177  	defer p.mu.RUnlock()
 178  	return len(p.assigned)
 179  }
 180  
 181  // SubnetForSequence returns the subnet for a given sequence number.
 182  func (p *SubnetPool) SubnetForSequence(seq uint32) Subnet {
 183  	return p.deriveSubnet(seq)
 184  }
 185