#include "textflag.h"

// Geometry of the inputs, named so the loop bounds below are self-describing.
#define GNARL_BLOCK_BYTES 41   // len(block); 41*8 = 328 bits of storage
#define GNARL_VALID_BITS  324  // 12*27 basis rows; bits 324..327 are padding
#define GNARL_ROW_SHIFT   6    // one basis row = 32 uint16 = 64 bytes = 1<<6

// func gnarlAccumulateAVX2(acc *[32]uint32, basis *[12][27][32]uint16, block *[41]byte)
//
// Sums selected basis rows into a 32-lane uint32 vector.
//
// block is read as a little-endian bit string: flat bit index n is bit (n%8)
// of block[n/8]. For every set bit n in 0..323, basis row n (32 uint16 values,
// row n starts at byte offset n*64 from basis) is zero-extended to uint32 and
// added lane-wise to a running sum. Bits 324..327 of the last byte are padding
// and are ignored.
//
// acc is WRITE-ONLY: the running sum starts at zero and the final sum
// overwrites acc[0..31]. Whatever acc held on entry is discarded, so the
// caller does not need to clear it (and cannot use it to chain calls).
//
// Overflow: at most 324 rows of at most 65535 each are summed per lane
// (324*65535 < 1<<32), so the uint32 lanes cannot wrap.
//
// Alignment: none required. All vector loads/stores use forms that accept
// unaligned addresses (VPMOVZXWD with a memory source, VMOVDQU).
//
// CPU features: AVX2 (256-bit VPXOR/VPMOVZXWD/VPADDD) and BMI1 (TZCNT, BLSR).
// This routine performs no feature detection; the caller must have checked
// both before calling.
//
// Strategy: walk block one byte at a time, skip zero bytes outright, and
// within a non-zero byte use TZCNT + BLSR so that only set bits cost an
// iteration.
//
// ABI: Go stack-based assembly calling convention; 24 bytes of arguments,
// no results, no local frame ($0-24), NOSPLIT leaf.
//
// Register allocation:
//   BX  = acc pointer
//   SI  = basis pointer
//   DI  = block pointer
//   R8  = byte index into block (0..40)
//   R9  = bits of the current byte not yet processed
//   R10 = flat index of bit 0 of the current byte (R8 * 8)
//   R11 = pointer to the current basis row
//   AX  = TZCNT result (position of lowest set bit in R9)
//   CX  = flat bit index, then byte offset of the row (index << 6)
//   Y0-Y3 = running sums for lanes 0..7, 8..15, 16..23, 24..31
//   Y4-Y7 = zero-extended row data
TEXT ·gnarlAccumulateAVX2(SB), NOSPLIT, $0-24
	MOVQ acc+0(FP), BX
	MOVQ basis+8(FP), SI
	MOVQ block+16(FP), DI

	// Running sums start at zero; acc itself is never read.
	VPXOR Y0, Y0, Y0
	VPXOR Y1, Y1, Y1
	VPXOR Y2, Y2, Y2
	VPXOR Y3, Y3, Y3

	XORQ R8, R8          // byteIdx = 0

	// The block is a fixed 41 bytes, so the loop is tested at the bottom:
	// the first iteration is always valid.
byteloop:
	// Load the next block byte; a zero byte contributes nothing.
	MOVBQZX (DI)(R8*1), R9
	TESTQ R9, R9
	JZ nextbyte

	// R10 = flat index of this byte's bit 0.
	MOVQ R8, R10
	SHLQ $3, R10

	// Invariant: R9 != 0 on every entry to bitloop.
bitloop:
	TZCNTQ R9, AX        // AX = position of lowest remaining set bit

	// Flat bit index = R10 + AX.
	LEAQ (R10)(AX*1), CX

	// Padding check. Bits are visited in ascending order, so once one bit
	// is out of range every remaining bit of this byte is too: leave the
	// byte instead of clearing the rest one by one. Only the last byte can
	// trigger this.
	CMPQ CX, $GNARL_VALID_BITS
	JAE nextbyte

	// R11 = &basis[row] = basis + index * 64.
	SHLQ $GNARL_ROW_SHIFT, CX
	LEAQ (SI)(CX*1), R11

	// Each VPMOVZXWD reads 8 uint16 (16 bytes) and widens them to 8 uint32;
	// four of them cover the 64-byte row.
	VPMOVZXWD (R11), Y4
	VPADDD Y4, Y0, Y0

	VPMOVZXWD 16(R11), Y5
	VPADDD Y5, Y1, Y1

	VPMOVZXWD 32(R11), Y6
	VPADDD Y6, Y2, Y2

	VPMOVZXWD 48(R11), Y7
	VPADDD Y7, Y3, Y3

	// Drop the bit just handled (R9 &= R9 - 1). BLSR sets ZF when nothing
	// is left; the vector instructions above do not touch flags.
	BLSRQ R9, R9
	JNZ bitloop

nextbyte:
	INCQ R8
	CMPQ R8, $GNARL_BLOCK_BYTES
	JB byteloop          // unsigned compare: R8 is an index

	// Publish the sums, overwriting acc[0..31].
	VMOVDQU Y0, (BX)
	VMOVDQU Y1, 32(BX)
	VMOVDQU Y2, 64(BX)
	VMOVDQU Y3, 96(BX)

	// Clear upper YMM state before returning to non-AVX code.
	VZEROUPPER
	RET
