subnet_pool.go raw
1 package wireguard
2
3 import (
4 "crypto/sha256"
5 "encoding/binary"
6 "fmt"
7 "net/netip"
8 "sync"
9
10 "lukechampine.com/frand"
11 )
12
// Subnet is a /31 point-to-point link. The even address of the pair is used
// on the server side and the odd address on the client side.
type Subnet struct {
	// ServerIP is the even (lower) address of the /31 pair.
	ServerIP netip.Addr
	// ClientIP is the odd (upper) address of the /31 pair.
	ClientIP netip.Addr
}
18
// SubnetPool manages deterministic /31 subnet generation from a seed.
// Given the same seed, base network and sequence number, the same subnet is
// always generated. Seed and RestoreAllocation exist so the pool's state can
// be persisted and rebuilt.
//
// seed and basePrefix are set at construction and read without locking;
// maxSeq and assigned are guarded by mu.
type SubnetPool struct {
	seed       [32]byte          // random seed for deterministic generation
	basePrefix netip.Prefix      // parent network the /31s are carved from, e.g. 10.0.0.0/8
	maxSeq     uint32            // one past the highest sequence number recorded (0 = none)
	assigned   map[string]uint32 // client pubkey hex -> sequence number
	mu         sync.RWMutex      // guards maxSeq and assigned
}
28
29 // NewSubnetPool creates a subnet pool with a new random seed.
30 func NewSubnetPool(baseNetwork string) (*SubnetPool, error) {
31 prefix, err := netip.ParsePrefix(baseNetwork)
32 if err != nil {
33 return nil, fmt.Errorf("invalid base network: %w", err)
34 }
35
36 var seed [32]byte
37 frand.Read(seed[:])
38
39 return &SubnetPool{
40 seed: seed,
41 basePrefix: prefix,
42 maxSeq: 0,
43 assigned: make(map[string]uint32),
44 }, nil
45 }
46
47 // NewSubnetPoolWithSeed creates a subnet pool with an existing seed.
48 func NewSubnetPoolWithSeed(baseNetwork string, seed []byte) (*SubnetPool, error) {
49 prefix, err := netip.ParsePrefix(baseNetwork)
50 if err != nil {
51 return nil, fmt.Errorf("invalid base network: %w", err)
52 }
53
54 if len(seed) != 32 {
55 return nil, fmt.Errorf("seed must be 32 bytes, got %d", len(seed))
56 }
57
58 pool := &SubnetPool{
59 basePrefix: prefix,
60 maxSeq: 0,
61 assigned: make(map[string]uint32),
62 }
63 copy(pool.seed[:], seed)
64
65 return pool, nil
66 }
67
68 // Seed returns the pool's seed for persistence.
69 func (p *SubnetPool) Seed() []byte {
70 return p.seed[:]
71 }
72
73 // deriveSubnet deterministically generates a /31 subnet from seed + sequence.
74 func (p *SubnetPool) deriveSubnet(seq uint32) Subnet {
75 // Hash seed + sequence to get deterministic randomness
76 h := sha256.New()
77 h.Write(p.seed[:])
78 binary.Write(h, binary.BigEndian, seq)
79 hash := h.Sum(nil)
80
81 // Use first 4 bytes as offset within the prefix
82 offset := binary.BigEndian.Uint32(hash[:4])
83
84 // Calculate available address space
85 bits := p.basePrefix.Bits()
86 availableBits := uint32(32 - bits)
87 maxOffset := uint32(1) << availableBits
88
89 // Make offset even (for /31 alignment) and within range
90 offset = (offset % (maxOffset / 2)) * 2
91
92 // Calculate server IP (even) and client IP (odd)
93 baseAddr := p.basePrefix.Addr()
94 baseBytes := baseAddr.As4()
95 baseVal := uint32(baseBytes[0])<<24 | uint32(baseBytes[1])<<16 |
96 uint32(baseBytes[2])<<8 | uint32(baseBytes[3])
97
98 serverVal := baseVal + offset
99 clientVal := serverVal + 1
100
101 serverBytes := [4]byte{
102 byte(serverVal >> 24), byte(serverVal >> 16),
103 byte(serverVal >> 8), byte(serverVal),
104 }
105 clientBytes := [4]byte{
106 byte(clientVal >> 24), byte(clientVal >> 16),
107 byte(clientVal >> 8), byte(clientVal),
108 }
109
110 return Subnet{
111 ServerIP: netip.AddrFrom4(serverBytes),
112 ClientIP: netip.AddrFrom4(clientBytes),
113 }
114 }
115
116 // ServerIPs returns server-side IPs for sequences 0 to maxSeq (for netstack).
117 func (p *SubnetPool) ServerIPs() []netip.Addr {
118 p.mu.RLock()
119 defer p.mu.RUnlock()
120
121 if p.maxSeq == 0 {
122 return nil
123 }
124
125 ips := make([]netip.Addr, p.maxSeq)
126 for seq := uint32(0); seq < p.maxSeq; seq++ {
127 subnet := p.deriveSubnet(seq)
128 ips[seq] = subnet.ServerIP
129 }
130 return ips
131 }
132
133 // GetSubnet returns the subnet for a client, or nil if not assigned.
134 func (p *SubnetPool) GetSubnet(clientPubkeyHex string) *Subnet {
135 p.mu.RLock()
136 defer p.mu.RUnlock()
137
138 if seq, ok := p.assigned[clientPubkeyHex]; ok {
139 subnet := p.deriveSubnet(seq)
140 return &subnet
141 }
142 return nil
143 }
144
145 // GetSequence returns the sequence number for a client, or -1 if not assigned.
146 func (p *SubnetPool) GetSequence(clientPubkeyHex string) int {
147 p.mu.RLock()
148 defer p.mu.RUnlock()
149
150 if seq, ok := p.assigned[clientPubkeyHex]; ok {
151 return int(seq)
152 }
153 return -1
154 }
155
156 // RestoreAllocation restores a previously saved allocation.
157 func (p *SubnetPool) RestoreAllocation(clientPubkeyHex string, seq uint32) {
158 p.mu.Lock()
159 defer p.mu.Unlock()
160
161 p.assigned[clientPubkeyHex] = seq
162 if seq >= p.maxSeq {
163 p.maxSeq = seq + 1
164 }
165 }
166
167 // MaxSequence returns the current max sequence number.
168 func (p *SubnetPool) MaxSequence() uint32 {
169 p.mu.RLock()
170 defer p.mu.RUnlock()
171 return p.maxSeq
172 }
173
174 // AllocatedCount returns the number of allocated subnets.
175 func (p *SubnetPool) AllocatedCount() int {
176 p.mu.RLock()
177 defer p.mu.RUnlock()
178 return len(p.assigned)
179 }
180
181 // SubnetForSequence returns the subnet for a given sequence number.
182 func (p *SubnetPool) SubnetForSequence(seq uint32) Subnet {
183 return p.deriveSubnet(seq)
184 }
185