poly.go raw
1 package common
2
3 // An element of our base ring R which are polynomials over ℤ_q
4 // modulo the equation Xᴺ = -1, where q=3329 and N=256.
5 //
6 // This type is also used to store NTT-transformed polynomials,
7 // see Poly.NTT().
8 //
9 // Coefficients aren't always reduced. See Normalize().
10 type Poly [N]int16
11
12 // Sets p to a + b. Does not normalize coefficients.
13 func (p *Poly) addGeneric(a, b *Poly) {
14 for i := 0; i < N; i++ {
15 p[i] = a[i] + b[i]
16 }
17 }
18
19 // Sets p to a - b. Does not normalize coefficients.
20 func (p *Poly) subGeneric(a, b *Poly) {
21 for i := 0; i < N; i++ {
22 p[i] = a[i] - b[i]
23 }
24 }
25
26 // Almost normalizes coefficients.
27 //
28 // Ensures each coefficient is in {0, …, q}.
29 func (p *Poly) barrettReduceGeneric() {
30 for i := 0; i < N; i++ {
31 p[i] = barrettReduce(p[i])
32 }
33 }
34
35 // Normalizes coefficients.
36 //
37 // Ensures each coefficient is in {0, …, q-1}.
38 func (p *Poly) normalizeGeneric() {
39 for i := 0; i < N; i++ {
40 p[i] = csubq(barrettReduce(p[i]))
41 }
42 }
43
44 // Multiplies p in-place by the Montgomery factor 2¹⁶.
45 //
46 // Coefficients of p can be arbitrary. Resulting coefficients are bounded
47 // in absolute value by q.
48 func (p *Poly) ToMont() {
49 for i := 0; i < N; i++ {
50 p[i] = toMont(p[i])
51 }
52 }
53
54 // Sets p to the "pointwise" multiplication of a and b.
55 //
56 // That is: InvNTT(p) = InvNTT(a) * InvNTT(b). Assumes a and b are in
57 // Montgomery form. Products between coefficients of a and b must be strictly
58 // bounded in absolute value by 2¹⁵q. p will be in Montgomery form and
59 // bounded in absolute value by 2q.
60 //
61 // Requires a and b to be in "tangled" order, see Tangle(). p will be in
62 // tangled order as well.
63 func (p *Poly) mulHatGeneric(a, b *Poly) {
64 // Recall from the discussion in NTT(), that a transformed polynomial is
65 // an element of ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷);
66 // that is: 128 degree-one polynomials instead of simply 256 elements
67 // from ℤ_q as in the regular NTT. So instead of pointwise multiplication,
68 // we multiply the 128 pairs of degree-one polynomials modulo the
69 // right equation:
70 //
71 // (a₁ + a₂x)(b₁ + b₂x) = a₁b₁ + a₂b₂ζ' + (a₁b₂ + a₂b₁)x,
72 //
73 // where ζ' is the appropriate power of ζ.
74
75 k := 64
76 for i := 0; i < N; i += 4 {
77 zeta := int32(Zetas[k])
78 k++
79
80 p0 := montReduce(int32(a[i+1]) * int32(b[i+1]))
81 p0 = montReduce(int32(p0) * zeta)
82 p0 += montReduce(int32(a[i]) * int32(b[i]))
83
84 p1 := montReduce(int32(a[i]) * int32(b[i+1]))
85 p1 += montReduce(int32(a[i+1]) * int32(b[i]))
86
87 p[i] = p0
88 p[i+1] = p1
89
90 p2 := montReduce(int32(a[i+3]) * int32(b[i+3]))
91 p2 = -montReduce(int32(p2) * zeta)
92 p2 += montReduce(int32(a[i+2]) * int32(b[i+2]))
93
94 p3 := montReduce(int32(a[i+2]) * int32(b[i+3]))
95 p3 += montReduce(int32(a[i+3]) * int32(b[i+2]))
96
97 p[i+2] = p2
98 p[i+3] = p3
99 }
100 }
101
102 // Packs p into buf. buf should be of length PolySize.
103 //
104 // Assumes p is normalized (and not just Barrett reduced) and "tangled",
105 // see Tangle().
106 func (p *Poly) Pack(buf []byte) {
107 q := *p
108 q.Detangle()
109 for i := 0; i < 128; i++ {
110 t0 := q[2*i]
111 t1 := q[2*i+1]
112 buf[3*i] = byte(t0)
113 buf[3*i+1] = byte(t0>>8) | byte(t1<<4)
114 buf[3*i+2] = byte(t1 >> 4)
115 }
116 }
117
118 // Unpacks p from buf.
119 //
120 // buf should be of length PolySize. p will be "tangled", see Detangle().
121 //
122 // p will not be normalized; instead 0 ≤ p[i] < 4096.
123 func (p *Poly) Unpack(buf []byte) {
124 for i := 0; i < 128; i++ {
125 p[2*i] = int16(buf[3*i]) | ((int16(buf[3*i+1]) << 8) & 0xfff)
126 p[2*i+1] = int16(buf[3*i+1]>>4) | (int16(buf[3*i+2]) << 4)
127 }
128 p.Tangle()
129 }
130
131 // Set p to Decompress_q(m, 1).
132 //
133 // p will be normalized. m has to be of PlaintextSize.
134 func (p *Poly) DecompressMessage(m []byte) {
135 // Decompress_q(x, 1) = ⌈xq/2⌋ = ⌊xq/2+½⌋ = (xq+1) >> 1 and so
136 // Decompress_q(0, 1) = 0 and Decompress_q(1, 1) = (q+1)/2.
137 for i := 0; i < 32; i++ {
138 for j := 0; j < 8; j++ {
139 bit := (m[i] >> uint(j)) & 1
140
141 // Set coefficient to either 0 or (q+1)/2 depending on the bit.
142 p[8*i+j] = -int16(bit) & ((Q + 1) / 2)
143 }
144 }
145 }
146
147 // Writes Compress_q(p, 1) to m.
148 //
149 // Assumes p is normalized. m has to be of length at least PlaintextSize.
150 func (p *Poly) CompressMessageTo(m []byte) {
151 // Compress_q(x, 1) is 1 on {833, …, 2496} and zero elsewhere.
152 for i := 0; i < 32; i++ {
153 m[i] = 0
154 for j := 0; j < 8; j++ {
155 x := 1664 - p[8*i+j]
156 // With the previous substitution, we want to return 1 if
157 // and only if x is in {831, …, -832}.
158 x = (x >> 15) ^ x
159 // Note (x >> 15)ˣ if x≥0 and -x-1 otherwise. Thus now we want
160 // to return 1 iff x ≤ 831, ie. x - 832 < 0.
161 x -= 832
162 m[i] |= ((byte(x >> 15)) & 1) << uint(j)
163 }
164 }
165 }
166
167 // Set p to Decompress_q(m, 1).
168 //
169 // Assumes d is in {4, 5, 10, 11}. p will be normalized.
170 func (p *Poly) Decompress(m []byte, d int) {
171 // Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
172 // = ⌊(q/2ᵈ)x+½⌋
173 // = ⌊(qx + 2ᵈ⁻¹)/2ᵈ⌋
174 // = (qx + (1<<(d-1))) >> d
175 switch d {
176 case 4:
177 for i := 0; i < N/2; i++ {
178 p[2*i] = int16(((1 << 3) +
179 uint32(m[i]&15)*uint32(Q)) >> 4)
180 p[2*i+1] = int16(((1 << 3) +
181 uint32(m[i]>>4)*uint32(Q)) >> 4)
182 }
183 case 5:
184 var t [8]uint16
185 idx := 0
186 for i := 0; i < N/8; i++ {
187 t[0] = uint16(m[idx])
188 t[1] = (uint16(m[idx]) >> 5) | (uint16(m[idx+1] << 3))
189 t[2] = uint16(m[idx+1]) >> 2
190 t[3] = (uint16(m[idx+1]) >> 7) | (uint16(m[idx+2] << 1))
191 t[4] = (uint16(m[idx+2]) >> 4) | (uint16(m[idx+3] << 4))
192 t[5] = uint16(m[idx+3]) >> 1
193 t[6] = (uint16(m[idx+3]) >> 6) | (uint16(m[idx+4] << 2))
194 t[7] = uint16(m[idx+4]) >> 3
195
196 for j := 0; j < 8; j++ {
197 p[8*i+j] = int16(((1 << 4) +
198 uint32(t[j]&((1<<5)-1))*uint32(Q)) >> 5)
199 }
200
201 idx += 5
202 }
203
204 case 10:
205 var t [4]uint16
206 idx := 0
207 for i := 0; i < N/4; i++ {
208 t[0] = uint16(m[idx]) | (uint16(m[idx+1]) << 8)
209 t[1] = (uint16(m[idx+1]) >> 2) | (uint16(m[idx+2]) << 6)
210 t[2] = (uint16(m[idx+2]) >> 4) | (uint16(m[idx+3]) << 4)
211 t[3] = (uint16(m[idx+3]) >> 6) | (uint16(m[idx+4]) << 2)
212
213 for j := 0; j < 4; j++ {
214 p[4*i+j] = int16(((1 << 9) +
215 uint32(t[j]&((1<<10)-1))*uint32(Q)) >> 10)
216 }
217
218 idx += 5
219 }
220 case 11:
221 var t [8]uint16
222 idx := 0
223 for i := 0; i < N/8; i++ {
224 t[0] = uint16(m[idx]) | (uint16(m[idx+1]) << 8)
225 t[1] = (uint16(m[idx+1]) >> 3) | (uint16(m[idx+2]) << 5)
226 t[2] = (uint16(m[idx+2]) >> 6) | (uint16(m[idx+3]) << 2) | (uint16(m[idx+4]) << 10)
227 t[3] = (uint16(m[idx+4]) >> 1) | (uint16(m[idx+5]) << 7)
228 t[4] = (uint16(m[idx+5]) >> 4) | (uint16(m[idx+6]) << 4)
229 t[5] = (uint16(m[idx+6]) >> 7) | (uint16(m[idx+7]) << 1) | (uint16(m[idx+8]) << 9)
230 t[6] = (uint16(m[idx+8]) >> 2) | (uint16(m[idx+9]) << 6)
231 t[7] = (uint16(m[idx+9]) >> 5) | (uint16(m[idx+10]) << 3)
232
233 for j := 0; j < 8; j++ {
234 p[8*i+j] = int16(((1 << 10) +
235 uint32(t[j]&((1<<11)-1))*uint32(Q)) >> 11)
236 }
237
238 idx += 11
239 }
240 default:
241 panic("unsupported d")
242 }
243 }
244
245 // Writes Compress_q(p, d) to m.
246 //
247 // Assumes p is normalized and d is in {4, 5, 10, 11}.
248 func (p *Poly) CompressTo(m []byte, d int) {
249 // Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
250 // = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
251 // = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
252 // = DIV((x << d) + q/2, q) & ((1<<d) - 1)
253 //
254 // We approximate DIV(x, q) by computing (x*a)>>e, where a/(2^e) ≈ 1/q.
255 // For d in {10,11} we use 20,642,679/2^36, which computes division by x/q
256 // correctly for 0 ≤ x < 41,522,616, which fits (q << 11) + q/2 comfortably.
257 // For d in {4,5} we use 315/2^20, which doesn't compute division by x/q
258 // correctly for all inputs, but it's close enough that the end result
259 // of the compression is correct. The advantage is that we do not need
260 // to use a 64-bit intermediate value.
261 switch d {
262 case 4:
263 var t [8]uint16
264 idx := 0
265 for i := 0; i < N/8; i++ {
266 for j := 0; j < 8; j++ {
267 t[j] = uint16((((uint32(p[8*i+j])<<4)+uint32(Q)/2)*315)>>
268 20) & ((1 << 4) - 1)
269 }
270 m[idx] = byte(t[0]) | byte(t[1]<<4)
271 m[idx+1] = byte(t[2]) | byte(t[3]<<4)
272 m[idx+2] = byte(t[4]) | byte(t[5]<<4)
273 m[idx+3] = byte(t[6]) | byte(t[7]<<4)
274 idx += 4
275 }
276
277 case 5:
278 var t [8]uint16
279 idx := 0
280 for i := 0; i < N/8; i++ {
281 for j := 0; j < 8; j++ {
282 t[j] = uint16((((uint32(p[8*i+j])<<5)+uint32(Q)/2)*315)>>
283 20) & ((1 << 5) - 1)
284 }
285 m[idx] = byte(t[0]) | byte(t[1]<<5)
286 m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
287 m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
288 m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
289 m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
290 idx += 5
291 }
292
293 case 10:
294 var t [4]uint16
295 idx := 0
296 for i := 0; i < N/4; i++ {
297 for j := 0; j < 4; j++ {
298 t[j] = uint16((uint64((uint32(p[4*i+j])<<10)+uint32(Q)/2)*
299 20642679)>>36) & ((1 << 10) - 1)
300 }
301 m[idx] = byte(t[0])
302 m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
303 m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
304 m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
305 m[idx+4] = byte(t[3] >> 2)
306 idx += 5
307 }
308 case 11:
309 var t [8]uint16
310 idx := 0
311 for i := 0; i < N/8; i++ {
312 for j := 0; j < 8; j++ {
313 t[j] = uint16((uint64((uint32(p[8*i+j])<<11)+uint32(Q)/2)*
314 20642679)>>36) & ((1 << 11) - 1)
315 }
316 m[idx] = byte(t[0])
317 m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
318 m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
319 m[idx+3] = byte(t[2] >> 2)
320 m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
321 m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
322 m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
323 m[idx+7] = byte(t[5] >> 1)
324 m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
325 m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
326 m[idx+10] = byte(t[7] >> 3)
327 idx += 11
328 }
329 default:
330 panic("unsupported d")
331 }
332 }
333