// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense

// The coarse rasterizer stage of the pipeline.
//
// As input we have the ordered partitions of paths from the binning phase and
// the annotated tile list of segments and backdrop per path.
//
// Each workgroup operates on one bin, stream compacting
// the elements corresponding to the bin.
//
// As output we have an ordered command stream per tile. Every tile from a path (backdrop + segment list) will be encoded.

#version 450
#extension GL_GOOGLE_include_directive : enable

#include "mem.h"
#include "setup.h"

// Workgroups are one-dimensional with N_TILE invocations; main() maps each
// invocation index to one tile of the bin (x = index % N_TILE_X,
// y = index / N_TILE_X).
layout(local_size_x = N_TILE, local_size_y = 1) in;

// Read-only pipeline configuration, bound at set 0, binding 1.
layout(set = 0, binding = 1) readonly buffer ConfigBuf {
    Config conf;
};

#include "annotated.h"
#include "bins.h"
#include "tile.h"
#include "ptcl.h"

// Number of input partitions whose headers are read in one batch (and its
// base-2 logarithm). main() uses the first N_PART_READ invocations to read
// one partition header each.
#define LG_N_PART_READ (7 + LG_WG_FACTOR)
#define N_PART_READ (1 << LG_N_PART_READ)

// Element indices of the current batch; slot i is written by invocation i
// from the bin instance it reads.
shared uint sh_elements[N_TILE];

// Number of elements in the partition; prefix sum.
// After the scan each entry is offset by the start index of the batch, so
// entry i is the running element index at which partition i ends.
shared uint sh_part_count[N_PART_READ];
// Allocation covering the bin instances of each partition in the batch.
shared Alloc sh_part_elements[N_PART_READ];

// Per-tile coverage bitmaps: bit (e & 31) of sh_bitmaps[e / 32][t] is set when
// the element in slot e of sh_elements is to be emitted for tile t of the bin.
shared uint sh_bitmaps[N_SLICE][N_TILE];

// Inclusive prefix sum, over element slots, of the number of tiles each
// element covers inside this bin.
shared uint sh_tile_count[N_TILE];
// The width of the tile rect for the element, intersected with this bin
shared uint sh_tile_width[N_TILE];
// Top-left corner of that tile rect, in tiles relative to the bin origin.
shared uint sh_tile_x0[N_TILE];
shared uint sh_tile_y0[N_TILE];

// These are set up so base + tile_y * stride + tile_x points to a Tile.
// (tile_x and tile_y are bin-relative; main() scales the index by Tile_size.)
shared uint sh_tile_base[N_TILE];
shared uint sh_tile_stride[N_TILE];

#ifdef MEM_DEBUG
// Tile allocation of each element slot. Only present in MEM_DEBUG builds so
// that other builds do not pay for the extra shared memory traffic.
shared Alloc sh_tile_alloc[N_TILE];

// Remember `a` as the tile allocation of element slot `el_ix`.
void write_tile_alloc(uint el_ix, Alloc a) {
    sh_tile_alloc[el_ix] = a;
}

// Return the tile allocation remembered for element slot `el_ix`.
// `mem_ok` is accepted for signature parity with the non-debug variant and is
// not consulted here.
Alloc read_tile_alloc(uint el_ix, bool mem_ok) {
    Alloc recorded = sh_tile_alloc[el_ix];
    return recorded;
}
#else
// Without MEM_DEBUG nothing is recorded; both arguments are ignored.
void write_tile_alloc(uint el_ix, Alloc a) {
}

// Without MEM_DEBUG every slot maps to one allocation starting at offset 0
// whose size is `memory.length() * 4` (presumably the whole buffer in bytes,
// assuming 4-byte words -- confirm against mem.h).
Alloc read_tile_alloc(uint el_ix, bool mem_ok) {
    uint whole_buffer_size = memory.length() * 4;
    return new_alloc(0, whole_buffer_size, mem_ok);
}
#endif

// The maximum number of commands per annotated element.
#define ANNO_COMMANDS 2

// Make sure the command list has room for one more element's commands.
//
// If `cmd_ref` is still below `cmd_limit` nothing changes. Otherwise a new
// chunk of PTCL_INITIAL_ALLOC bytes is requested, a jump to it is written at
// the current position of the old chunk, and `cmd_alloc`, `cmd_ref` and
// `cmd_limit` are all switched over to the new chunk.
//
// Returns false, leaving all three arguments untouched, when the allocation
// fails; returns true otherwise.
//
// (Open style question: perhaps cmd_alloc should be a global instead.)
bool alloc_cmd(inout Alloc cmd_alloc, inout CmdRef cmd_ref, inout uint cmd_limit) {
    bool has_room = cmd_ref.offset < cmd_limit;
    if (!has_room) {
        MallocResult chunk = malloc(PTCL_INITIAL_ALLOC);
        if (chunk.failed) {
            return false;
        }
        // Link the exhausted chunk to its successor before switching over.
        Cmd_Jump_write(cmd_alloc, cmd_ref, CmdJump(chunk.alloc.offset));
        cmd_alloc = chunk.alloc;
        cmd_ref = CmdRef(chunk.alloc.offset);
        // Hold back enough space at the end of the chunk for the largest
        // per-element command group plus one trailing jump.
        cmd_limit = chunk.alloc.offset + PTCL_INITIAL_ALLOC - (ANNO_COMMANDS + 1) * Cmd_size;
    }
    return true;
}

// Append the coverage command for one tile of a path and advance `cmd_ref`
// past it.
//
// - Fill mode other than MODE_NONZERO: a Stroke command carrying the tile's
//   segment list offset and half of `linewidth`.
// - MODE_NONZERO with no segments in the tile (offset 0): a Solid command.
// - MODE_NONZERO with segments: a Fill command carrying the segment list
//   offset and the tile's backdrop.
//
// Each advance includes 4 extra bytes on top of the payload size (presumably
// the command tag word -- confirm against ptcl.h).
void write_fill(Alloc alloc, inout CmdRef cmd_ref, uint flags, Tile tile, float linewidth) {
    uint seg_offset = tile.tile.offset;
    bool is_stroke = fill_mode_from_flags(flags) != MODE_NONZERO;
    if (is_stroke) {
        Cmd_Stroke_write(alloc, cmd_ref, CmdStroke(seg_offset, 0.5 * linewidth));
        cmd_ref.offset += 4 + CmdStroke_size;
    } else if (seg_offset == 0) {
        Cmd_Solid_write(alloc, cmd_ref);
        cmd_ref.offset += 4;
    } else {
        Cmd_Fill_write(alloc, cmd_ref, CmdFill(seg_offset, tile.backdrop));
        cmd_ref.offset += 4 + CmdFill_size;
    }
}

// Entry point of the coarse rasterizer.
//
// One workgroup processes one bin and each invocation owns one tile of that
// bin. The outer loop repeatedly:
//   1. gathers up to N_TILE element indices for this bin into sh_elements,
//      reading partition headers in batches of N_PART_READ;
//   2. has each invocation look at one element and compute how many tiles of
//      the bin its bounding box covers;
//   3. distributes the resulting (element, tile) pairs over the invocations,
//      which set bits in sh_bitmaps for the tiles where the element matters;
//   4. has each invocation walk the bitmap of its own tile in slot order and
//      append commands to that tile's command list.
// Finally each tile inside the image gets an End command and, if clips were
// emitted, a scratch allocation for the fine rasterizer.
void main() {
    // Could use either linear or 2d layouts for both dispatch and
    // invocations within the workgroup. We'll use variables to abstract.
    uint width_in_bins = (conf.width_in_tiles + N_TILE_X - 1)/N_TILE_X;
    uint bin_ix = width_in_bins * gl_WorkGroupID.y + gl_WorkGroupID.x;
    // Index of the next partition header still to be read.
    uint partition_ix = 0;
    // Elements are grouped into partitions of N_TILE; round up.
    uint n_partitions = (conf.n_elements + N_TILE - 1) / N_TILE;
    uint th_ix = gl_LocalInvocationID.x;

    // Coordinates of top left of bin, in tiles.
    uint bin_tile_x = N_TILE_X * gl_WorkGroupID.x;
    uint bin_tile_y = N_TILE_Y * gl_WorkGroupID.y;

    // Per-tile state
    // Position of this invocation's tile inside the bin, and its row-major
    // index in the whole image.
    uint tile_x = gl_LocalInvocationID.x % N_TILE_X;
    uint tile_y = gl_LocalInvocationID.x / N_TILE_X;
    uint this_tile_ix = (bin_tile_y + tile_y) * conf.width_in_tiles + bin_tile_x + tile_x;
    // Every tile starts with its own PTCL_INITIAL_ALLOC-sized slice of the
    // command list area; alloc_cmd() chains further chunks on demand.
    Alloc cmd_alloc = slice_mem(conf.ptcl_alloc, this_tile_ix * PTCL_INITIAL_ALLOC, PTCL_INITIAL_ALLOC);
    CmdRef cmd_ref = CmdRef(cmd_alloc.offset);
    // Reserve space for the maximum number of commands and a potential jump.
    uint cmd_limit = cmd_ref.offset + PTCL_INITIAL_ALLOC - (ANNO_COMMANDS + 1) * Cmd_size;
    // The nesting depth of the clip stack
    uint clip_depth = 0;
    // State for the "clip zero" optimization. If it's nonzero, then we are
    // currently in a clip for which the entire tile has an alpha of zero, and
    // the value is the depth after the "begin clip" of that element.
    uint clip_zero_depth = 0;
    // State for the "clip one" optimization. If bit `i` is set, then that means
    // that the clip pushed at depth `i` has an alpha of all one.
    // Only depths below 32 can be tracked in this mask.
    uint clip_one_mask = 0;

    // I'm sure we can figure out how to do this with at least one fewer register...
    // Items up to rd_ix have been read from sh_elements
    uint rd_ix = 0;
    // Items up to wr_ix have been written into sh_elements
    uint wr_ix = 0;
    // Items between part_start_ix and ready_ix are ready to be transferred from sh_part_elements
    uint part_start_ix = 0;
    uint ready_ix = 0;

    // Leave room for the fine rasterizer scratch allocation.
    // The first Alloc_size bytes of the command list are set aside for it and
    // filled in at the very end of main().
    Alloc scratch_alloc = slice_mem(cmd_alloc, 0, Alloc_size);
    cmd_ref.offset += Alloc_size;

    // begin_slot counts the BeginClip commands currently open in the output;
    // num_begin_slots is the maximum it ever reached.
    uint num_begin_slots = 0;
    uint begin_slot = 0;
    // Snapshot of the memory error state taken once on entry; it is not
    // refreshed if an allocation fails later in this function.
    bool mem_ok = mem_error == NO_ERROR;
    while (true) {
        // Clear this tile's coverage bitmaps for the new batch of elements.
        for (uint i = 0; i < N_SLICE; i++) {
            sh_bitmaps[i][th_ix] = 0;
        }

        // parallel read of input partitions
        // Repeat until sh_elements holds N_TILE elements or the input for this
        // bin is exhausted.
        // NOTE(review): the barrier() calls below sit inside conditionals; the
        // conditions appear to be identical for every invocation of the
        // workgroup, which barrier() relies on.
        do {
            // Everything from the previous batch of partitions has been
            // consumed: read the next N_PART_READ partition headers.
            if (ready_ix == wr_ix && partition_ix < n_partitions) {
                part_start_ix = ready_ix;
                uint count = 0;
                if (th_ix < N_PART_READ && partition_ix + th_ix < n_partitions) {
                    // Each header is two words (count, offset), indexed by
                    // partition and then by bin.
                    uint in_ix = (conf.bin_alloc.offset >> 2) + ((partition_ix + th_ix) * N_TILE + bin_ix) * 2;
                    count = read_mem(conf.bin_alloc, in_ix);
                    uint offset = read_mem(conf.bin_alloc, in_ix + 1);
                    sh_part_elements[th_ix] = new_alloc(offset, count*BinInstance_size, mem_ok);
                }
                // prefix sum of counts
                // Log-step inclusive scan: in round i every invocation adds
                // the running value held 2^i slots below it.
                for (uint i = 0; i < LG_N_PART_READ; i++) {
                    if (th_ix < N_PART_READ) {
                        sh_part_count[th_ix] = count;
                    }
                    barrier();
                    if (th_ix < N_PART_READ) {
                        if (th_ix >= (1 << i)) {
                            count += sh_part_count[th_ix - (1 << i)];
                        }
                    }
                    barrier();
                }
                // Publish the scan result, rebased onto the global element
                // numbering of this bin.
                if (th_ix < N_PART_READ) {
                    sh_part_count[th_ix] = part_start_ix + count;
                }
                barrier();
                // The last entry is the total number of elements now available.
                ready_ix = sh_part_count[N_PART_READ - 1];
                partition_ix += N_PART_READ;
            }
            // use binary search to find element to read
            uint ix = rd_ix + th_ix;
            // Only slots that are not yet filled and for which an element is
            // available take part.
            if (ix >= wr_ix && ix < ready_ix && mem_ok) {
                // Find the partition whose range of element indices contains ix.
                uint part_ix = 0;
                for (uint i = 0; i < LG_N_PART_READ; i++) {
                    uint probe = part_ix + ((N_PART_READ / 2) >> i);
                    if (ix >= sh_part_count[probe - 1]) {
                        part_ix = probe;
                    }
                }
                // Convert ix into an index relative to the start of that partition.
                ix -= part_ix > 0 ? sh_part_count[part_ix - 1] : part_start_ix;
                Alloc bin_alloc = sh_part_elements[part_ix];
                BinInstanceRef inst_ref = BinInstanceRef(bin_alloc.offset);
                BinInstance inst = BinInstance_read(bin_alloc, BinInstance_index(inst_ref, ix));
                sh_elements[th_ix] = inst.element_ix;
            }
            barrier();

            wr_ix = min(rd_ix + N_TILE, ready_ix);
        } while (wr_ix - rd_ix < N_TILE && (wr_ix < ready_ix || partition_ix < n_partitions));

        // We've done the merge and filled the buffer.

        // Read one element, compute coverage.
        // Invocations without an element in this batch keep the Nop tag and
        // contribute zero tiles below.
        uint tag = Annotated_Nop;
        uint element_ix;
        AnnotatedRef ref;
        if (th_ix + rd_ix < wr_ix) {
            element_ix = sh_elements[th_ix];
            ref = AnnotatedRef(conf.anno_alloc.offset + element_ix * Annotated_size);
            tag = Annotated_tag(conf.anno_alloc, ref).tag;
        }

        // Bounding box of the element's path, in tile coordinates relative to
        // this bin and clamped to the bin's extent.
        uint tile_count;
        switch (tag) {
        case Annotated_Color:
        case Annotated_Image:
        case Annotated_BeginClip:
        case Annotated_EndClip:
            // We have one "path" for each element, even if the element isn't
            // actually a path (currently EndClip, but images etc in the future).
            uint path_ix = element_ix;
            Path path = Path_read(conf.tile_alloc, PathRef(conf.tile_alloc.offset + path_ix * Path_size));
            // Row stride of the path's tile array is the bbox width in tiles.
            uint stride = path.bbox.z - path.bbox.x;
            sh_tile_stride[th_ix] = stride;
            // Offset of the bbox origin from the bin origin; may be negative.
            int dx = int(path.bbox.x) - int(bin_tile_x);
            int dy = int(path.bbox.y) - int(bin_tile_y);
            int x0 = clamp(dx, 0, N_TILE_X);
            int y0 = clamp(dy, 0, N_TILE_Y);
            int x1 = clamp(int(path.bbox.z) - int(bin_tile_x), 0, N_TILE_X);
            int y1 = clamp(int(path.bbox.w) - int(bin_tile_y), 0, N_TILE_Y);
            sh_tile_width[th_ix] = uint(x1 - x0);
            sh_tile_x0[th_ix] = x0;
            sh_tile_y0[th_ix] = y0;
            // Number of tiles of this bin covered by the bbox.
            tile_count = uint(x1 - x0) * uint(y1 - y0);
            // base relative to bin
            // Chosen so that base + (stride * y + x) * Tile_size addresses the
            // path's tile at bin-relative position (x, y).
            uint base = path.tiles.offset - uint(dy * stride + dx) * Tile_size;
            sh_tile_base[th_ix] = base;
            Alloc path_alloc = new_alloc(path.tiles.offset, (path.bbox.z - path.bbox.x) * (path.bbox.w - path.bbox.y) * Tile_size, mem_ok);
            write_tile_alloc(th_ix, path_alloc);
            break;
        default:
            tile_count = 0;
            break;
        }

        // Prefix sum of sh_tile_count
        // Same log-step inclusive scan as above, this time over all N_TILE slots.
        sh_tile_count[th_ix] = tile_count;
        for (uint i = 0; i < LG_N_TILE; i++) {
            barrier();
            if (th_ix >= (1 << i)) {
                tile_count += sh_tile_count[th_ix - (1 << i)];
            }
            barrier();
            sh_tile_count[th_ix] = tile_count;
        }
        barrier();
        // Total number of (element, tile) pairs in this batch. They are
        // processed N_TILE at a time, one pair per invocation per iteration.
        uint total_tile_count = sh_tile_count[N_TILE - 1];
        for (uint ix = th_ix; ix < total_tile_count; ix += N_TILE) {
            // Binary search to find element
            uint el_ix = 0;
            for (uint i = 0; i < LG_N_TILE; i++) {
                uint probe = el_ix + ((N_TILE / 2) >> i);
                if (ix >= sh_tile_count[probe - 1]) {
                    el_ix = probe;
                }
            }
            // Note: `ref` and `tag` here shadow the outer variables of the
            // same names for the duration of this loop body.
            AnnotatedRef ref = AnnotatedRef(conf.anno_alloc.offset + sh_elements[el_ix] * Annotated_size);
            uint tag = Annotated_tag(conf.anno_alloc, ref).tag;
            // Position of this pair within the element's tile rect, row-major.
            uint seq_ix = ix - (el_ix > 0 ? sh_tile_count[el_ix - 1] : 0);
            uint width = sh_tile_width[el_ix];
            uint x = sh_tile_x0[el_ix] + seq_ix % width;
            uint y = sh_tile_y0[el_ix] + seq_ix / width;
            bool include_tile = false;
            // Clip elements are recorded for every tile of their bbox without
            // reading the tile.
            if (tag == Annotated_BeginClip || tag == Annotated_EndClip) {
                include_tile = true;
            } else if (mem_ok) {
                Tile tile = Tile_read(read_tile_alloc(el_ix, mem_ok), TileRef(sh_tile_base[el_ix] + (sh_tile_stride[el_ix] * y + x) * Tile_size));
                // Include the path in the tile if
                // - the tile contains at least a segment (tile offset non-zero)
                // - the tile is completely covered (backdrop non-zero)
                include_tile = tile.tile.offset != 0 || tile.backdrop != 0;
            }
            if (include_tile) {
                // Set the bit for element slot el_ix in the bitmap of tile
                // (x, y); atomic because several invocations may target the
                // same word.
                uint el_slice = el_ix / 32;
                uint el_mask = 1 << (el_ix & 31);
                atomicOr(sh_bitmaps[el_slice][y * N_TILE_X + x], el_mask);
            }
        }

        barrier();

        // Output non-segment elements for this tile. The thread does a sequential walk
        // through the non-segment elements.
        // Set bits are visited from lowest to highest slot index, one 32-bit
        // slice at a time. Nothing is emitted when mem_ok is false.
        uint slice_ix = 0;
        uint bitmap = sh_bitmaps[0][th_ix];
        while (mem_ok) {
            if (bitmap == 0) {
                // Current slice exhausted: advance to the next one, or stop
                // after the last.
                slice_ix++;
                if (slice_ix == N_SLICE) {
                    break;
                }
                bitmap = sh_bitmaps[slice_ix][th_ix];
                if (bitmap == 0) {
                    continue;
                }
            }
            // Slot in sh_elements (and the other per-slot arrays) of the next
            // element to emit.
            uint element_ref_ix = slice_ix * 32 + findLSB(bitmap);
            uint element_ix = sh_elements[element_ref_ix];

            // Clear LSB
            bitmap &= bitmap - 1;

            // At this point, we read the element again from global memory.
            // If that turns out to be expensive, maybe we can pack it into
            // shared memory (or perhaps just the tag).
            ref = AnnotatedRef(conf.anno_alloc.offset + element_ix * Annotated_size);
            AnnotatedTag tag = Annotated_tag(conf.anno_alloc, ref);

            if (clip_zero_depth == 0) {
                switch (tag.tag) {
                case Annotated_Color:
                    Tile tile = Tile_read(read_tile_alloc(element_ref_ix, mem_ok), TileRef(sh_tile_base[element_ref_ix]
                        + (sh_tile_stride[element_ref_ix] * tile_y + tile_x) * Tile_size));
                    AnnoColor fill = Annotated_Color_read(conf.anno_alloc, ref);
                    // On allocation failure the element is dropped; this
                    // `break` leaves the switch only, not the while loop.
                    if (!alloc_cmd(cmd_alloc, cmd_ref, cmd_limit)) {
                        break;
                    }
                    // Coverage command followed by the color command.
                    write_fill(cmd_alloc, cmd_ref, tag.flags, tile, fill.linewidth);
                    Cmd_Color_write(cmd_alloc, cmd_ref, CmdColor(fill.rgba_color));
                    cmd_ref.offset += 4 + CmdColor_size;
                    break;
                case Annotated_Image:
                    tile = Tile_read(read_tile_alloc(element_ref_ix, mem_ok), TileRef(sh_tile_base[element_ref_ix]
                        + (sh_tile_stride[element_ref_ix] * tile_y + tile_x) * Tile_size));
                    AnnoImage fill_img = Annotated_Image_read(conf.anno_alloc, ref);
                    if (!alloc_cmd(cmd_alloc, cmd_ref, cmd_limit)) {
                        break;
                    }
                    // Coverage command followed by the image command.
                    write_fill(cmd_alloc, cmd_ref, tag.flags, tile, fill_img.linewidth);
                    Cmd_Image_write(cmd_alloc, cmd_ref, CmdImage(fill_img.index, fill_img.offset));
                    cmd_ref.offset += 4 + CmdImage_size;
                    break;
                case Annotated_BeginClip:
                    tile = Tile_read(read_tile_alloc(element_ref_ix, mem_ok), TileRef(sh_tile_base[element_ref_ix]
                        + (sh_tile_stride[element_ref_ix] * tile_y + tile_x) * Tile_size));
                    if (tile.tile.offset == 0 && tile.backdrop == 0) {
                        // No segments and no backdrop: enter the "clip zero"
                        // state and suppress output until the matching EndClip.
                        clip_zero_depth = clip_depth + 1;
                    } else if (tile.tile.offset == 0 && clip_depth < 32) {
                        // No segments but a non-zero backdrop: record a
                        // "clip one" at this depth and emit nothing.
                        clip_one_mask |= (1 << clip_depth);
                    } else {
                        // General case: emit the clip's coverage and a
                        // BeginClip command.
                        AnnoBeginClip begin_clip = Annotated_BeginClip_read(conf.anno_alloc, ref);
                        // Note: on failure this skips the clip_depth++ below
                        // as well, since the break leaves the whole case.
                        if (!alloc_cmd(cmd_alloc, cmd_ref, cmd_limit)) {
                            break;
                        }
                        write_fill(cmd_alloc, cmd_ref, tag.flags, tile, begin_clip.linewidth);
                        Cmd_BeginClip_write(cmd_alloc, cmd_ref);
                        cmd_ref.offset += 4;
                        // Make sure a stale "clip one" bit from an earlier
                        // clip at this depth does not apply to this one.
                        if (clip_depth < 32) {
                            clip_one_mask &= ~(1 << clip_depth);
                        }
                        begin_slot++;
                        num_begin_slots = max(num_begin_slots, begin_slot);
                    }
                    clip_depth++;
                    break;
                case Annotated_EndClip:
                    clip_depth--;
                    // Emit an EndClip only if the matching BeginClip was
                    // emitted, i.e. it was not elided as a "clip one".
                    if (clip_depth >= 32 || (clip_one_mask & (1 << clip_depth)) == 0) {
                        if (!alloc_cmd(cmd_alloc, cmd_ref, cmd_limit)) {
                            break;
                        }
                        Cmd_Solid_write(cmd_alloc, cmd_ref);
                        cmd_ref.offset += 4;
                        begin_slot--;
                        Cmd_EndClip_write(cmd_alloc, cmd_ref);
                        cmd_ref.offset += 4;
                    }
                    break;
                }
            } else {
                // In "clip zero" state, suppress all drawing
                // Only the nesting depth is tracked, so that the EndClip
                // matching the zero clip can be recognized.
                switch (tag.tag) {
                case Annotated_BeginClip:
                    clip_depth++;
                    break;
                case Annotated_EndClip:
                    if (clip_depth == clip_zero_depth) {
                        clip_zero_depth = 0;
                    }
                    clip_depth--;
                    break;
                }
            }
        }
        // All invocations must be done with the shared arrays before the next
        // batch overwrites them.
        barrier();

        // Advance to the next batch; stop once every available element has
        // been consumed and no partitions remain.
        rd_ix += N_TILE;
        if (rd_ix >= ready_ix && partition_ix >= n_partitions) break;
    }
    // Invocations whose tile lies outside the image do not finalize a
    // command list.
    if (bin_tile_x + tile_x < conf.width_in_tiles && bin_tile_y + tile_y < conf.height_in_tiles) {
        Cmd_End_write(cmd_alloc, cmd_ref);
        if (num_begin_slots > 0) {
            // Write scratch allocation: one state per BeginClip per rasterizer chunk.
            uint scratch_size = num_begin_slots * TILE_WIDTH_PX * TILE_HEIGHT_PX * CLIP_STATE_SIZE * 4;
            MallocResult scratch = malloc(scratch_size);
            // Ignore scratch.failed; we don't use the allocation and kernel4
            // checks for memory overflow before using it.
            alloc_write(scratch_alloc, scratch_alloc.offset, scratch.alloc);
        }
    }
}
