sample.go raw
1 package ring
2
3 import (
4 "crypto/rand"
5 "encoding/binary"
6 "io"
7 )
8
9 // UniformPoly samples a polynomial with coefficients uniformly random in [0, Q).
10 func UniformPoly(p Params) *Poly {
11 return UniformPolyFrom(p, rand.Reader)
12 }
13
14 // UniformPolyFrom samples a uniform polynomial using the given randomness source.
15 // Uses rejection sampling to avoid modular bias.
16 //
17 // For efficiency, reads entropy in bulk and uses the largest rejection bound
18 // that fits in the sample width. For q=257 with 16-bit samples: accept if
19 // v < floor(65536/257)*257 = 65535, then v % 257. Acceptance rate ~99.998%.
20 func UniformPolyFrom(p Params, rng io.Reader) *Poly {
21 poly := New(p)
22 q := p.Q
23
24 if q <= (1 << 16) {
25 // 16-bit sampling path (q fits in uint16).
26 rejBound := uint32((65536 / uint32(q)) * uint32(q)) // largest multiple of q ≤ 65536
27
28 // Batch read: need ~N samples, each 2 bytes. Over-allocate slightly
29 // for the rare rejection.
30 batchSize := p.N + 8 // small margin for rejections
31 buf := make([]byte, batchSize*2)
32 if _, err := io.ReadFull(rng, buf); err != nil {
33 panic("ring: randomness source failed: " + err.Error())
34 }
35 pos := 0
36 for i := range poly.Coeffs {
37 for {
38 if pos+2 > len(buf) {
39 // Ran out of batch (very rare). Refill.
40 buf = make([]byte, 16)
41 io.ReadFull(rng, buf)
42 pos = 0
43 }
44 v := uint32(binary.LittleEndian.Uint16(buf[pos : pos+2]))
45 pos += 2
46 if v < rejBound {
47 poly.Coeffs[i] = v % q
48 break
49 }
50 }
51 }
52 } else {
53 // 32-bit sampling path (q > 16 bits).
54 rejBound := (uint64(1)<<32 / uint64(q)) * uint64(q)
55
56 batchSize := p.N + 8
57 buf := make([]byte, batchSize*4)
58 if _, err := io.ReadFull(rng, buf); err != nil {
59 panic("ring: randomness source failed: " + err.Error())
60 }
61 pos := 0
62 for i := range poly.Coeffs {
63 for {
64 if pos+4 > len(buf) {
65 buf = make([]byte, 32)
66 io.ReadFull(rng, buf)
67 pos = 0
68 }
69 v := uint64(binary.LittleEndian.Uint32(buf[pos : pos+4]))
70 pos += 4
71 if v < rejBound {
72 poly.Coeffs[i] = uint32(v % uint64(q))
73 break
74 }
75 }
76 }
77 }
78 return poly
79 }
80
81 // TernaryPoly samples a polynomial with coefficients in {-1, 0, 1}.
82 // Each coefficient is 0 with probability 1/2, and ±1 with probability 1/4 each.
83 func TernaryPoly(p Params) *Poly {
84 return TernaryPolyFrom(p, rand.Reader)
85 }
86
87 // TernaryPolyFrom samples a ternary polynomial from the given source.
88 func TernaryPolyFrom(p Params, rng io.Reader) *Poly {
89 poly := New(p)
90 buf := make([]byte, p.N)
91 if _, err := io.ReadFull(rng, buf); err != nil {
92 panic("ring: randomness source failed: " + err.Error())
93 }
94 for i := range poly.Coeffs {
95 b := buf[i]
96 switch {
97 case b < 64: // probability 1/4
98 poly.Coeffs[i] = p.Q - 1 // -1 mod Q
99 case b < 128: // probability 1/4
100 poly.Coeffs[i] = 1
101 default: // probability 1/2
102 poly.Coeffs[i] = 0
103 }
104 }
105 return poly
106 }
107
108 // SmallPoly samples a polynomial with coefficients in [-bound, bound].
109 // Uses centered representation: negative values stored as Q - |v|.
110 func SmallPoly(p Params, bound uint32) *Poly {
111 return SmallPolyFrom(p, bound, rand.Reader)
112 }
113
114 // SmallPolyFrom samples a small polynomial from the given source.
115 func SmallPolyFrom(p Params, bound uint32, rng io.Reader) *Poly {
116 poly := New(p)
117 width := 2*bound + 1
118 buf := make([]byte, 2)
119 for i := range poly.Coeffs {
120 for {
121 if _, err := io.ReadFull(rng, buf); err != nil {
122 panic("ring: randomness source failed: " + err.Error())
123 }
124 v := uint32(binary.LittleEndian.Uint16(buf))
125 if v < (65535/width)*width { // rejection sampling for uniform
126 v = v % width
127 if v <= bound {
128 poly.Coeffs[i] = v
129 } else {
130 poly.Coeffs[i] = p.Q - (width - v)
131 }
132 break
133 }
134 }
135 }
136 return poly
137 }
138
139 // CBDPoly samples from a centered binomial distribution with parameter eta.
140 // CBD_eta produces coefficients in [-eta, eta].
141 // This is the noise distribution used in Kyber/ML-KEM.
142 func CBDPoly(p Params, eta int) *Poly {
143 return CBDPolyFrom(p, eta, rand.Reader)
144 }
145
146 // CBDPolyFrom samples CBD from the given source.
147 func CBDPolyFrom(p Params, eta int, rng io.Reader) *Poly {
148 poly := New(p)
149 bytesNeeded := (eta * 2 * p.N) / 8
150 buf := make([]byte, bytesNeeded)
151 if _, err := io.ReadFull(rng, buf); err != nil {
152 panic("ring: randomness source failed: " + err.Error())
153 }
154
155 bitIdx := 0
156 getBit := func() uint32 {
157 bytePos := bitIdx / 8
158 bitPos := uint(bitIdx % 8)
159 bitIdx++
160 return uint32((buf[bytePos] >> bitPos) & 1)
161 }
162
163 for i := range poly.Coeffs {
164 var a, b uint32
165 for j := range eta {
166 _ = j
167 a += getBit()
168 }
169 for j := range eta {
170 _ = j
171 b += getBit()
172 }
173 if a >= b {
174 poly.Coeffs[i] = a - b
175 } else {
176 poly.Coeffs[i] = p.Q - (b - a)
177 }
178 }
179 return poly
180 }
181
182 // Norm computes the infinity norm (max absolute coefficient) of a polynomial.
183 // Coefficients are interpreted as centered: values > Q/2 are negative.
184 func Norm(a *Poly) uint32 {
185 q := a.params.Q
186 half := q / 2
187 maxVal := uint32(0)
188 for _, c := range a.Coeffs {
189 v := c
190 if v > half {
191 v = q - v // |c| for negative values
192 }
193 if v > maxVal {
194 maxVal = v
195 }
196 }
197 return maxVal
198 }
199
200 // Serialize packs polynomial coefficients into a byte slice.
201 // Each coefficient is stored as ceil(log2(Q)) bits, little-endian packed.
202 func Serialize(a *Poly) []byte {
203 q := a.params.Q
204 bits := 0
205 for v := q - 1; v > 0; v >>= 1 {
206 bits++
207 }
208
209 totalBits := bits * len(a.Coeffs)
210 out := make([]byte, (totalBits+7)/8)
211
212 bitPos := 0
213 for _, c := range a.Coeffs {
214 for b := 0; b < bits; b++ {
215 if c&(1<<uint(b)) != 0 {
216 out[bitPos/8] |= 1 << uint(bitPos%8)
217 }
218 bitPos++
219 }
220 }
221 return out
222 }
223
224 // Deserialize unpacks a byte slice into a polynomial.
225 func Deserialize(p Params, data []byte) *Poly {
226 q := p.Q
227 bits := 0
228 for v := q - 1; v > 0; v >>= 1 {
229 bits++
230 }
231
232 poly := New(p)
233 bitPos := 0
234 for i := range poly.Coeffs {
235 var v uint32
236 for b := 0; b < bits; b++ {
237 if bitPos/8 < len(data) && data[bitPos/8]&(1<<uint(bitPos%8)) != 0 {
238 v |= 1 << uint(b)
239 }
240 bitPos++
241 }
242 if v >= q {
243 v = q - 1 // clamp (shouldn't happen with well-formed data)
244 }
245 poly.Coeffs[i] = v
246 }
247 return poly
248 }
249