02. Kernel Architecture

Five cuTile kernel types, fused attention with on-chip decompression, and Blackwell-specific optimizations.

The Five Kernel Types

The engine is built from five kernel types that cover the full compression-and-attention pipeline. Each has 2-bit and 3-bit variants (matching the key and value bit budgets), bringing the total to around 10 individual kernels. All are written in NVIDIA cuTile, a Python DSL that compiles to optimized GPU code targeting Tensor Cores and TMA (the Tensor Memory Accelerator).

  1. Key Compression: rotate → Lloyd-Max quantize → store QJL signs
  2. Value Compression: rotate → Lloyd-Max quantize (3-bit, 8 centroids)
  3. Value Decompression: centroid lookup → un-rotate via Πᵀ
  4. Attention Scoring: MSE dot product + QJL correction term
  5. Fused Attention: score + softmax + V accumulation in one pass

Five kernel types covering the full TurboQuant pipeline.
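To make the first three kernels concrete, here is a minimal NumPy sketch of the compression and decompression math, not the cuTile kernels themselves. It assumes a Haar-random rotation Π built from a QR decomposition, a Lloyd-Max codebook fitted with Lloyd's algorithm to the roughly N(0, 1/d) distribution that rotated unit-vector coordinates follow, a Gaussian QJL projection S, and the standard QJL estimator with scale √(π/2)/m. All function names are hypothetical, and storing per-vector norms is an assumption borrowed from the kernel's V_Norms and R_norms arguments.

```python
import numpy as np

def random_rotation(d: int, rng: np.random.Generator) -> np.ndarray:
    """Haar-random orthogonal matrix Pi from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.standard_normal((d, d)))
    return q * np.sign(np.diag(r))  # fix column signs so the distribution is uniform

def lloyd_max(samples: np.ndarray, bits: int, iters: int = 50) -> np.ndarray:
    """Fit 2**bits scalar centroids with Lloyd's algorithm (assign, then re-average)."""
    k = 2 ** bits
    centroids = np.quantile(samples, (np.arange(k) + 0.5) / k)
    for _ in range(iters):
        idx = np.abs(samples[:, None] - centroids[None, :]).argmin(axis=1)
        for j in range(k):
            members = samples[idx == j]
            if members.size:
                centroids[j] = members.mean()
    return np.sort(centroids)

def quantize(x: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Index of the nearest centroid for every coordinate, stored as uint8."""
    return np.abs(x[..., None] - centroids).argmin(axis=-1).astype(np.uint8)

def compress_values(V, Pi, centroids):
    """Value path: normalize each row, rotate (rows @ Pi.T), Lloyd-Max quantize."""
    norms = np.linalg.norm(V, axis=1)
    idx = quantize((V / norms[:, None]) @ Pi.T, centroids)
    return idx, norms.astype(np.float16)

def decompress_values(idx, norms, centroids, Pi):
    """Centroid lookup, un-rotate (rows @ Pi applies Pi^T to each vector), rescale."""
    return (centroids[idx] @ Pi) * norms[:, None]

def compress_keys(K, Pi, centroids, S):
    """Key path: same MSE quantizer, plus 1-bit QJL signs of the residual."""
    idx, norms = compress_values(K, Pi, centroids)
    k_mse = decompress_values(idx, norms, centroids, Pi)
    residual = K - k_mse
    signs = np.where(residual @ S.T >= 0, 1, -1).astype(np.int8)
    return k_mse, signs, np.linalg.norm(residual, axis=1)

def estimate_scores(q, k_mse, signs, r_norms, S):
    """<q, k> ~= <q, k_mse> + ||r|| * sqrt(pi/2)/m * <S q, sign(S r)>  (QJL estimator)."""
    correction_scale = np.sqrt(np.pi / 2) / S.shape[0]
    return k_mse @ q + r_norms * correction_scale * (signs @ (S @ q))

rng = np.random.default_rng(0)
d = 128
Pi = random_rotation(d, rng)
# Rotated unit-vector coordinates are roughly N(0, 1/d), so fit the codebook to that
centroids = lloyd_max(rng.standard_normal(20_000) / np.sqrt(d), bits=3)
V = rng.standard_normal((256, d))
V_idx, V_norms = compress_values(V, Pi, centroids)
V_hat = decompress_values(V_idx, V_norms, centroids, Pi)
```

Rows are rotated as `x @ Pi.T`, so un-rotating is `y @ Pi`, which is the same as applying Πᵀ to each column vector.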

Fused Attention Kernel

The fused attention kernel is where things get interesting. It performs scoring, QJL bias correction, softmax, and the weighted value accumulation in a single GPU kernel, with no intermediate writes to global memory.

How it works

Each thread block owns a single query tile. It loads Q once into on-chip memory and then streams through every KV block in sequence:

Per KV Block

  1. ct.mma(Q, K_mseT): bulk attention score against MSE-compressed keys
  2. ct.mma(Q_proj, SignsT): QJL correction inner product
  3. Combine: score = scale · term1 + ||r|| · correction_scale · term2
  4. Online softmax update (running max + sum of exponentials)
  5. ct.mma(P, V): weighted value accumulation
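The five steps above can be checked against plain softmax attention with a small NumPy reference. This is a sketch under assumptions, not the kernel: term1 and term2 follow steps 1 and 2, the combination uses the scale · term1 + ||r|| · correction_scale · term2 form from the kernel listing, and correction_scale is assumed to fold in the softmax scale (scale · √(π/2)/m). The online softmax keeps a per-row running max m and sum l, and rescales the accumulator by 2^((m_old - m_new) · log₂ e) whenever the max grows. The function name and block size are hypothetical.

```python
import numpy as np

LOG2_E = 1.0 / np.log(2.0)  # exp(x) == exp2(x * LOG2_E)

def fused_attention_reference(Q, K_mse, Signs, R_norms, Q_proj, V,
                              scale, correction_scale, block_kv=64):
    """NumPy sketch of the per-KV-block loop for one query tile."""
    n_q, d_v = Q.shape[0], V.shape[1]
    m = np.full(n_q, -np.inf)   # running row max
    l = np.zeros(n_q)           # running sum of exponentials
    acc = np.zeros((n_q, d_v))  # unnormalized output
    for start in range(0, K_mse.shape[0], block_kv):
        blk = slice(start, start + block_kv)
        term1 = Q @ K_mse[blk].T                       # 1. score vs MSE keys
        term2 = Q_proj @ Signs[blk].T                  # 2. QJL inner product
        scores = term1 * scale + R_norms[blk] * correction_scale * term2  # 3. combine
        m_new = np.maximum(m, scores.max(axis=1))      # 4. online softmax: new max
        alpha = np.exp2((m - m_new) * LOG2_E)          #    rescale factor (0 on 1st block)
        p = np.exp2((scores - m_new[:, None]) * LOG2_E)
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ V[blk]        # 5. weighted value accumulation
        m = m_new
    return acc / l[:, None]                            # final softmax normalization

rng = np.random.default_rng(0)
n_q, n_k, d, m_proj = 16, 200, 64, 64
Q = rng.standard_normal((n_q, d))
K_mse = rng.standard_normal((n_k, d))
S = rng.standard_normal((m_proj, d))                   # hypothetical QJL projection
Q_proj = Q @ S.T                                      # S q for every query row
Signs = rng.choice([-1.0, 1.0], size=(n_k, m_proj))   # stand-in for sign(S r)
R_norms = rng.random(n_k)
V = rng.standard_normal((n_k, d))
scale = 1.0 / np.sqrt(d)
correction_scale = scale * np.sqrt(np.pi / 2) / m_proj  # assumed to fold in `scale`
out = fused_attention_reference(Q, K_mse, Signs, R_norms, Q_proj, V,
                                scale, correction_scale)
```

With n_k = 200 and 64-key blocks, the last block is partial, which the slice handles; the result matches a one-shot softmax over all 200 scores.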

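TMA copies the compressed data from HBM into shared memory, and on Blackwell the Tensor Cores read their operands from there and accumulate into TMEM (dedicated on-chip tensor memory) instead of the register file, so the data reaches the Tensor Cores in a single hop. At the end you have your output: one kernel, one pass, no intermediate round-trips to HBM.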

Fused attention: Q tile stays fixed, KV blocks stream through, output accumulates.

Kernel signature

turboquant_cutile/attention.py
```python
@ct.kernel(occupancy=2)
def turboquant_fused_attention_vfused_3bit(
    Q, K_mse, Signs, R_norms, Q_proj,
    V_Indices, V_Norms, Pi, Output,
    scale: float,
    correction_scale: float,
    seq_k: int,
    # ...8 centroid values unpacked as constants...
    USE_SWIZZLE: ConstBool,
):
    # Excerpt: BLOCK_Q, HEAD_DIM, INV_LOG_2, num_kv_blocks, qp_tile (the Q_proj tile),
    # norms, v_tile, the m_prev / acc initialization and the running exp-sum
    # update are not shown.
    q_block = ct.bid(0)
    # Load Q tile once; it stays on-chip for the whole KV loop
    q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM))

    # Load rotation matrix Pi into shared memory (resident across all KV blocks)
    pi_tile = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    for kv_block in range(num_kv_blocks):
        # Step 1, score: Q @ K_mse^T
        k_tile = ct.load(K_mse, index=(kv_block, 0), ...)
        term1 = ct.mma(q_tile, k_tile, transpose_b=True)

        # Step 2, QJL correction: Q_proj @ Signs^T
        s_tile = ct.load(Signs, index=(kv_block, 0), ...)
        term2 = ct.mma(qp_tile, s_tile, transpose_b=True)

        # Steps 3-4, combine + online softmax
        # (norms presumably holds this block's R_norms; INV_LOG_2 is presumably
        #  1/ln 2, so exp2(x * INV_LOG_2) == exp(x))
        scores = term1 * scale + norms * correction_scale * term2
        m_new = ct.max(m_prev, ct.max(scores))
        p_tile = ct.exp2((scores - m_new) * INV_LOG_2)
        # Step 5: rescale the old accumulator, then add P @ V
        # (v_tile: this block's values, decompressed on-chip by the V-fused path)
        acc = acc * ct.exp2((m_prev - m_new) * INV_LOG_2) + ct.mma(p_tile, v_tile)
        ...
```

V-Fused Decompression

Rather than decompressing values in a separate kernel (writing FP16 to HBM, then reading it back), the fused attention kernel decompresses values on-chip inside the KV block loop.

The rotation matrix Π (128×128, 32 KB in FP16) stays resident in shared memory across all KV blocks. For each block, the kernel loads raw uint8 indices and FP16 norms, does the centroid lookup, un-rotates via ct.mma with Π, scales by norms, and feeds the result straight into the P × V accumulation.
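A NumPy sketch of the V-fused path: values are decompressed one KV block at a time and consumed immediately by P × V, so only a block-sized slice of floating-point values exists at once. The codebook values and the QR-based stand-in for Π are hypothetical placeholders, and the function name is illustrative; the point is only that blockwise decompression gives the same product as materializing all of V first.

```python
import numpy as np

def pv_with_fused_decompression(P, V_indices, V_norms, centroids, Pi, block_kv=64):
    """Accumulate P @ V while decompressing V one KV block at a time.

    Only a (block_kv, d) slice of decompressed values exists at any moment,
    standing in for "never written to HBM".
    """
    acc = np.zeros((P.shape[0], Pi.shape[1]))
    for start in range(0, V_indices.shape[0], block_kv):
        blk = slice(start, start + block_kv)
        v_rot = centroids[V_indices[blk]]           # centroid lookup: uint8 -> float
        v_blk = (v_rot @ Pi) * V_norms[blk, None]   # un-rotate with Pi, scale by norms
        acc += P[:, blk] @ v_blk                    # feed straight into P x V
    return acc

rng = np.random.default_rng(0)
d, n_k = 128, 300
Pi, _ = np.linalg.qr(rng.standard_normal((d, d)))   # stand-in orthogonal rotation
centroids = np.linspace(-0.2, 0.2, 8)                # hypothetical 3-bit codebook
V_indices = rng.integers(0, 8, size=(n_k, d), dtype=np.uint8)
V_norms = (rng.random(n_k) + 0.5).astype(np.float16)
P = rng.random((8, n_k))
P /= P.sum(axis=1, keepdims=True)                    # rows sum to 1, like softmax output
out = pv_with_fused_decompression(P, V_indices, V_norms, centroids, Pi)
```

Because the rotation is orthogonal, scaling by the norm before or after the un-rotation gives the same vector; doing it after keeps the matrix multiply on the raw centroid values.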

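Result: one fewer kernel launch and one fewer HBM round-trip. For a 2048-token sequence with head dimension 128, the decompressed FP16 values would occupy 2048 × 128 × 2 bytes = 512 KB per head; the fused kernel never writes that buffer to HBM or reads it back.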

Blackwell-Specific Optimizations

Block Swizzling

Thread block IDs are interleaved with a stride of 4 to spread L2 cache pressure: adjacent blocks hit different cache lines instead of thrashing the same partition. At 16K tokens, swizzling cuts fused attention latency from 954 μs to 899 μs (about 6%).
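The exact swizzle formula isn't spelled out here. One hypothetical reading of "interleaved with a stride of 4" is a permutation in which consecutive launched blocks are assigned tiles 4 apart (0, 4, 8, and so on, then 1, 5, 9). The sketch below computes that mapping in plain Python; `swizzle_block_id` is an illustrative name, not part of cuTile.

```python
def swizzle_block_id(bid: int, num_blocks: int, stride: int = 4) -> int:
    """Map launch order to tile index so consecutive blocks land `stride` tiles apart.

    Hypothetical interleave: launch order visits tiles 0, 4, 8, ..., then 1, 5, 9, ...
    Works for any num_blocks (not only multiples of stride) and is a bijection.
    """
    if not 0 <= bid < num_blocks:
        raise ValueError("bid out of range")
    offset = bid
    for residue in range(stride):
        # Number of tiles t in [0, num_blocks) with t % stride == residue
        count = (num_blocks - residue + stride - 1) // stride
        if offset < count:
            return residue + offset * stride
        offset -= count
    raise AssertionError("unreachable: the counts sum to num_blocks")

# Launch order -> tile order for 10 blocks
print([swizzle_block_id(b, 10) for b in range(10)])  # → [0, 4, 8, 1, 5, 9, 2, 6, 3, 7]
```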

Pipelined TMA Loads

latency=2 hints on the loads let the compiler pipeline them, prefetching the next KV block while the current one is being processed. This overlaps memory transfers with Tensor Core compute.
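The effect of the prefetch is ordinary software pipelining. As a CPU-side analogy only (cuTile generates the real pipeline from the latency hint), the hypothetical loop below issues the load for block i + 1 on a background thread before computing on block i, so transfer and compute overlap with one block in flight.

```python
from concurrent.futures import ThreadPoolExecutor

def pipelined_loop(num_blocks, load, compute):
    """Software-pipelined loop: block i+1 is loading while block i is computed."""
    results = []
    if num_blocks == 0:
        return results
    with ThreadPoolExecutor(max_workers=1) as loader:
        pending = loader.submit(load, 0)              # prologue: start the first load
        for i in range(num_blocks):
            current = pending.result()                # wait for block i to arrive
            if i + 1 < num_blocks:
                pending = loader.submit(load, i + 1)  # prefetch block i+1 ...
            results.append(compute(current))          # ... while computing on block i
    return results

# Usage: "load" fetches block i, "compute" consumes it
print(pipelined_loop(4, lambda i: list(range(i, i + 3)), sum))  # → [3, 6, 9, 12]
```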

exp2 with Flush-to-Zero

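ct.exp2(flush_to_zero=True) maps onto the native base-2 exponential fast path and flushes denormal results to zero. Softmax is reformulated in base 2 to take advantage: eˣ = 2^(x · log₂ e), so scores are multiplied by log₂ e before exponentiation.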
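A NumPy sketch of the base-2 rewrite: the softmax below uses only exp2, matches the natural-exponential softmax, and optionally emulates flush-to-zero by zeroing float32 results below the smallest normal value. The `ftz` flag is an illustrative emulation, not the cuTile parameter; the kernel listing's INV_LOG_2 constant presumably plays the role of `LOG2_E` here.

```python
import numpy as np

LOG2_E = np.log2(np.e)  # log2(e) = 1 / ln 2, roughly 1.4427

def softmax_base2(scores: np.ndarray, ftz: bool = False) -> np.ndarray:
    """Row-wise softmax using only exp2, since e**x == 2**(x * log2(e))."""
    shifted = (scores - scores.max(axis=-1, keepdims=True)) * LOG2_E
    p = np.exp2(shifted).astype(np.float32)
    if ftz:
        # Emulated flush-to-zero: float32 results below the smallest normal become 0
        p = np.where(p < np.finfo(np.float32).tiny, np.float32(0.0), p)
    return p / p.sum(axis=-1, keepdims=True)

x = np.array([[1.0, 2.0, 3.0], [0.0, -100.0, 5.0]])
print(softmax_base2(x, ftz=True))
```

Flushed entries were already below about 1e-38 relative to the row max, so dropping them does not visibly change the probabilities.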

Approximate Division

rounding_mode=RMd.APPROX trades a small loss of precision for faster approximate division in the softmax normalization.

Occupancy Tuning

@ct.kernel(occupancy=2) targets 2 concurrent thread blocks per SM, balancing register pressure against parallelism for latency hiding.

Online Softmax

Running max and exponential sum are maintained across KV blocks. No separate softmax pass needed. Numerically stable by construction.

Why cuTile?

cuTile is NVIDIA's Python DSL for tile-level GPU programming. The compiler handles Tensor Core scheduling, TMA dispatch, register allocation, and shared memory layout. You write Python. It compiles to optimized GPU code. No raw CUDA, no manual pointer arithmetic.

This is particularly powerful on Blackwell, where the TMEM → Tensor Core path and hardware TMA are first-class features that the cuTile compiler targets without any hand-written CUDA.

Anirudh Bharadwaj Vangara
MLE Intern @ Shopify · Computer Engineering @ University of Waterloo · MLH Top 50