poly.go raw

   1  package common
   2  
   3  // Poly is an element of our base ring R, that is: a polynomial over ℤ_q
   4  // modulo the equation Xᴺ = -1, where q=3329 and N=256.
   5  //
   6  // This type is also used to store NTT-transformed polynomials,
   7  // see Poly.NTT().
   8  //
   9  // Coefficients aren't always reduced.  See Normalize().
  10  type Poly [N]int16
  11  
  12  // Sets p to a + b.  Does not normalize coefficients.
  13  func (p *Poly) addGeneric(a, b *Poly) {
  14  	for i := 0; i < N; i++ {
  15  		p[i] = a[i] + b[i]
  16  	}
  17  }
  18  
  19  // Sets p to a - b.  Does not normalize coefficients.
  20  func (p *Poly) subGeneric(a, b *Poly) {
  21  	for i := 0; i < N; i++ {
  22  		p[i] = a[i] - b[i]
  23  	}
  24  }
  25  
  26  // Almost normalizes coefficients.
  27  //
  28  // Ensures each coefficient is in {0, …, q}.
  29  func (p *Poly) barrettReduceGeneric() {
  30  	for i := 0; i < N; i++ {
  31  		p[i] = barrettReduce(p[i])
  32  	}
  33  }
  34  
  35  // Normalizes coefficients.
  36  //
  37  // Ensures each coefficient is in {0, …, q-1}.
  38  func (p *Poly) normalizeGeneric() {
  39  	for i := 0; i < N; i++ {
  40  		p[i] = csubq(barrettReduce(p[i]))
  41  	}
  42  }
  43  
  44  // Multiplies p in-place by the Montgomery factor 2¹⁶.
  45  //
  46  // Coefficients of p can be arbitrary.  Resulting coefficients are bounded
  47  // in absolute value by q.
  48  func (p *Poly) ToMont() {
  49  	for i := 0; i < N; i++ {
  50  		p[i] = toMont(p[i])
  51  	}
  52  }
  53  
  54  // Sets p to the "pointwise" multiplication of a and b.
  55  //
  56  // That is: InvNTT(p) = InvNTT(a) * InvNTT(b).  Assumes a and b are in
  57  // Montgomery form.  Products between coefficients of a and b must be strictly
  58  // bounded in absolute value by 2¹⁵q.  p will be in Montgomery form and
  59  // bounded in absolute value by 2q.
  60  //
  61  // Requires a and b to be in "tangled" order, see Tangle().  p will be in
  62  // tangled order as well.
  63  func (p *Poly) mulHatGeneric(a, b *Poly) {
  64  	// Recall from the discussion in NTT(), that a transformed polynomial is
  65  	// an element of ℤ_q[x]/(x²-ζ) x … x  ℤ_q[x]/(x²+ζ¹²⁷);
  66  	// that is: 128 degree-one polynomials instead of simply 256 elements
  67  	// from ℤ_q as in the regular NTT.  So instead of pointwise multiplication,
  68  	// we multiply the 128 pairs of degree-one polynomials modulo the
  69  	// right equation:
  70  	//
  71  	//  (a₁ + a₂x)(b₁ + b₂x) = a₁b₁ + a₂b₂ζ' + (a₁b₂ + a₂b₁)x,
  72  	//
  73  	// where ζ' is the appropriate power of ζ.
  74  
  75  	k := 64
  76  	for i := 0; i < N; i += 4 {
  77  		zeta := int32(Zetas[k])
  78  		k++
  79  
  80  		p0 := montReduce(int32(a[i+1]) * int32(b[i+1]))
  81  		p0 = montReduce(int32(p0) * zeta)
  82  		p0 += montReduce(int32(a[i]) * int32(b[i]))
  83  
  84  		p1 := montReduce(int32(a[i]) * int32(b[i+1]))
  85  		p1 += montReduce(int32(a[i+1]) * int32(b[i]))
  86  
  87  		p[i] = p0
  88  		p[i+1] = p1
  89  
  90  		p2 := montReduce(int32(a[i+3]) * int32(b[i+3]))
  91  		p2 = -montReduce(int32(p2) * zeta)
  92  		p2 += montReduce(int32(a[i+2]) * int32(b[i+2]))
  93  
  94  		p3 := montReduce(int32(a[i+2]) * int32(b[i+3]))
  95  		p3 += montReduce(int32(a[i+3]) * int32(b[i+2]))
  96  
  97  		p[i+2] = p2
  98  		p[i+3] = p3
  99  	}
 100  }
 101  
 102  // Packs p into buf.  buf should be of length PolySize.
 103  //
 104  // Assumes p is normalized (and not just Barrett reduced) and "tangled",
 105  // see Tangle().
 106  func (p *Poly) Pack(buf []byte) {
 107  	q := *p
 108  	q.Detangle()
 109  	for i := 0; i < 128; i++ {
 110  		t0 := q[2*i]
 111  		t1 := q[2*i+1]
 112  		buf[3*i] = byte(t0)
 113  		buf[3*i+1] = byte(t0>>8) | byte(t1<<4)
 114  		buf[3*i+2] = byte(t1 >> 4)
 115  	}
 116  }
 117  
 118  // Unpacks p from buf.
 119  //
 120  // buf should be of length PolySize.  p will be "tangled", see Detangle().
 121  //
 122  // p will not be normalized; instead 0 ≤ p[i] < 4096.
 123  func (p *Poly) Unpack(buf []byte) {
 124  	for i := 0; i < 128; i++ {
 125  		p[2*i] = int16(buf[3*i]) | ((int16(buf[3*i+1]) << 8) & 0xfff)
 126  		p[2*i+1] = int16(buf[3*i+1]>>4) | (int16(buf[3*i+2]) << 4)
 127  	}
 128  	p.Tangle()
 129  }
 130  
 131  // Set p to Decompress_q(m, 1).
 132  //
 133  // p will be normalized.  m has to be of PlaintextSize.
 134  func (p *Poly) DecompressMessage(m []byte) {
 135  	// Decompress_q(x, 1) = ⌈xq/2⌋ = ⌊xq/2+½⌋ = (xq+1) >> 1 and so
 136  	// Decompress_q(0, 1) = 0 and Decompress_q(1, 1) = (q+1)/2.
 137  	for i := 0; i < 32; i++ {
 138  		for j := 0; j < 8; j++ {
 139  			bit := (m[i] >> uint(j)) & 1
 140  
 141  			// Set coefficient to either 0 or (q+1)/2 depending on the bit.
 142  			p[8*i+j] = -int16(bit) & ((Q + 1) / 2)
 143  		}
 144  	}
 145  }
 146  
 147  // Writes Compress_q(p, 1) to m.
 148  //
 149  // Assumes p is normalized.  m has to be of length at least PlaintextSize.
 150  func (p *Poly) CompressMessageTo(m []byte) {
 151  	// Compress_q(x, 1) is 1 on {833, …, 2496} and zero elsewhere.
 152  	for i := 0; i < 32; i++ {
 153  		m[i] = 0
 154  		for j := 0; j < 8; j++ {
 155  			x := 1664 - p[8*i+j]
 156  			// With the previous substitution, we want to return 1 if
 157  			// and only if x is in {831, …, -832}.
 158  			x = (x >> 15) ^ x
 159  			// Note (x >> 15)ˣ if x≥0 and -x-1 otherwise. Thus now we want
 160  			// to return 1 iff x ≤ 831, ie. x - 832 < 0.
 161  			x -= 832
 162  			m[i] |= ((byte(x >> 15)) & 1) << uint(j)
 163  		}
 164  	}
 165  }
 166  
 167  // Set p to Decompress_q(m, 1).
 168  //
 169  // Assumes d is in {4, 5, 10, 11}.  p will be normalized.
 170  func (p *Poly) Decompress(m []byte, d int) {
 171  	// Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
 172  	//                    = ⌊(q/2ᵈ)x+½⌋
 173  	//                    = ⌊(qx + 2ᵈ⁻¹)/2ᵈ⌋
 174  	//                    = (qx + (1<<(d-1))) >> d
 175  	switch d {
 176  	case 4:
 177  		for i := 0; i < N/2; i++ {
 178  			p[2*i] = int16(((1 << 3) +
 179  				uint32(m[i]&15)*uint32(Q)) >> 4)
 180  			p[2*i+1] = int16(((1 << 3) +
 181  				uint32(m[i]>>4)*uint32(Q)) >> 4)
 182  		}
 183  	case 5:
 184  		var t [8]uint16
 185  		idx := 0
 186  		for i := 0; i < N/8; i++ {
 187  			t[0] = uint16(m[idx])
 188  			t[1] = (uint16(m[idx]) >> 5) | (uint16(m[idx+1] << 3))
 189  			t[2] = uint16(m[idx+1]) >> 2
 190  			t[3] = (uint16(m[idx+1]) >> 7) | (uint16(m[idx+2] << 1))
 191  			t[4] = (uint16(m[idx+2]) >> 4) | (uint16(m[idx+3] << 4))
 192  			t[5] = uint16(m[idx+3]) >> 1
 193  			t[6] = (uint16(m[idx+3]) >> 6) | (uint16(m[idx+4] << 2))
 194  			t[7] = uint16(m[idx+4]) >> 3
 195  
 196  			for j := 0; j < 8; j++ {
 197  				p[8*i+j] = int16(((1 << 4) +
 198  					uint32(t[j]&((1<<5)-1))*uint32(Q)) >> 5)
 199  			}
 200  
 201  			idx += 5
 202  		}
 203  
 204  	case 10:
 205  		var t [4]uint16
 206  		idx := 0
 207  		for i := 0; i < N/4; i++ {
 208  			t[0] = uint16(m[idx]) | (uint16(m[idx+1]) << 8)
 209  			t[1] = (uint16(m[idx+1]) >> 2) | (uint16(m[idx+2]) << 6)
 210  			t[2] = (uint16(m[idx+2]) >> 4) | (uint16(m[idx+3]) << 4)
 211  			t[3] = (uint16(m[idx+3]) >> 6) | (uint16(m[idx+4]) << 2)
 212  
 213  			for j := 0; j < 4; j++ {
 214  				p[4*i+j] = int16(((1 << 9) +
 215  					uint32(t[j]&((1<<10)-1))*uint32(Q)) >> 10)
 216  			}
 217  
 218  			idx += 5
 219  		}
 220  	case 11:
 221  		var t [8]uint16
 222  		idx := 0
 223  		for i := 0; i < N/8; i++ {
 224  			t[0] = uint16(m[idx]) | (uint16(m[idx+1]) << 8)
 225  			t[1] = (uint16(m[idx+1]) >> 3) | (uint16(m[idx+2]) << 5)
 226  			t[2] = (uint16(m[idx+2]) >> 6) | (uint16(m[idx+3]) << 2) | (uint16(m[idx+4]) << 10)
 227  			t[3] = (uint16(m[idx+4]) >> 1) | (uint16(m[idx+5]) << 7)
 228  			t[4] = (uint16(m[idx+5]) >> 4) | (uint16(m[idx+6]) << 4)
 229  			t[5] = (uint16(m[idx+6]) >> 7) | (uint16(m[idx+7]) << 1) | (uint16(m[idx+8]) << 9)
 230  			t[6] = (uint16(m[idx+8]) >> 2) | (uint16(m[idx+9]) << 6)
 231  			t[7] = (uint16(m[idx+9]) >> 5) | (uint16(m[idx+10]) << 3)
 232  
 233  			for j := 0; j < 8; j++ {
 234  				p[8*i+j] = int16(((1 << 10) +
 235  					uint32(t[j]&((1<<11)-1))*uint32(Q)) >> 11)
 236  			}
 237  
 238  			idx += 11
 239  		}
 240  	default:
 241  		panic("unsupported d")
 242  	}
 243  }
 244  
 245  // Writes Compress_q(p, d) to m.
 246  //
 247  // Assumes p is normalized and d is in {4, 5, 10, 11}.
 248  func (p *Poly) CompressTo(m []byte, d int) {
 249  	// Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
 250  	//                  = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
 251  	//					= ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
 252  	//					= DIV((x << d) + q/2, q) & ((1<<d) - 1)
 253  	//
 254  	// We approximate DIV(x, q) by computing (x*a)>>e, where a/(2^e) ≈ 1/q.
 255  	// For d in {10,11} we use 20,642,679/2^36, which computes division by x/q
 256  	// correctly for 0 ≤ x < 41,522,616, which fits (q << 11) + q/2 comfortably.
 257  	// For d in {4,5} we use 315/2^20, which doesn't compute division by x/q
 258  	// correctly for all inputs, but it's close enough that the end result
 259  	// of the compression is correct. The advantage is that we do not need
 260  	// to use a 64-bit intermediate value.
 261  	switch d {
 262  	case 4:
 263  		var t [8]uint16
 264  		idx := 0
 265  		for i := 0; i < N/8; i++ {
 266  			for j := 0; j < 8; j++ {
 267  				t[j] = uint16((((uint32(p[8*i+j])<<4)+uint32(Q)/2)*315)>>
 268  					20) & ((1 << 4) - 1)
 269  			}
 270  			m[idx] = byte(t[0]) | byte(t[1]<<4)
 271  			m[idx+1] = byte(t[2]) | byte(t[3]<<4)
 272  			m[idx+2] = byte(t[4]) | byte(t[5]<<4)
 273  			m[idx+3] = byte(t[6]) | byte(t[7]<<4)
 274  			idx += 4
 275  		}
 276  
 277  	case 5:
 278  		var t [8]uint16
 279  		idx := 0
 280  		for i := 0; i < N/8; i++ {
 281  			for j := 0; j < 8; j++ {
 282  				t[j] = uint16((((uint32(p[8*i+j])<<5)+uint32(Q)/2)*315)>>
 283  					20) & ((1 << 5) - 1)
 284  			}
 285  			m[idx] = byte(t[0]) | byte(t[1]<<5)
 286  			m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
 287  			m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
 288  			m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
 289  			m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
 290  			idx += 5
 291  		}
 292  
 293  	case 10:
 294  		var t [4]uint16
 295  		idx := 0
 296  		for i := 0; i < N/4; i++ {
 297  			for j := 0; j < 4; j++ {
 298  				t[j] = uint16((uint64((uint32(p[4*i+j])<<10)+uint32(Q)/2)*
 299  					20642679)>>36) & ((1 << 10) - 1)
 300  			}
 301  			m[idx] = byte(t[0])
 302  			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
 303  			m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
 304  			m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
 305  			m[idx+4] = byte(t[3] >> 2)
 306  			idx += 5
 307  		}
 308  	case 11:
 309  		var t [8]uint16
 310  		idx := 0
 311  		for i := 0; i < N/8; i++ {
 312  			for j := 0; j < 8; j++ {
 313  				t[j] = uint16((uint64((uint32(p[8*i+j])<<11)+uint32(Q)/2)*
 314  					20642679)>>36) & ((1 << 11) - 1)
 315  			}
 316  			m[idx] = byte(t[0])
 317  			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
 318  			m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
 319  			m[idx+3] = byte(t[2] >> 2)
 320  			m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
 321  			m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
 322  			m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
 323  			m[idx+7] = byte(t[5] >> 1)
 324  			m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
 325  			m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
 326  			m[idx+10] = byte(t[7] >> 3)
 327  			idx += 11
 328  		}
 329  	default:
 330  		panic("unsupported d")
 331  	}
 332  }
 333