02. Kernel Architecture
Five cuTile kernel types, fused attention with on-chip decompression, and Blackwell-specific optimizations.
The Five Kernel Types
The engine is built from five kernel types that cover the full compression-and-attention pipeline. Each has 2-bit and 3-bit variants (matching the key and value bit budgets), bringing the total to ten individual kernels. All are written in NVIDIA cuTile, a Python DSL that compiles to optimized GPU code targeting Tensor Cores and TMA.
Fused Attention Kernel
The fused attention kernel is where things get interesting. It performs scoring, bias correction, softmax, and the weighted value accumulation in a single GPU kernel with no intermediate writes to global memory.
How it works
Each thread block owns a single query tile. It loads Q once into on-chip memory and then streams through every KV block in sequence:
Per KV Block
1. ct.mma(Q, K_mseT): bulk attention score against MSE-compressed keys
2. ct.mma(Q_proj, SignsT): QJL correction inner product
3. Combine: score = term1 + ||r|| · correction_scale · term2
4. Online softmax update (running max + sum of exponentials)
5. ct.mma(P, V): weighted value accumulation
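As a reference for the math above, here is a NumPy sketch of the combine step and the online-softmax accumulation, checked against a one-shot softmax. All names, shapes, and random inputs are illustrative stand-ins, not the kernel's actual tiles:

```python
import numpy as np

rng = np.random.default_rng(0)
D, N, B = 64, 256, 64                     # head dim, KV length, KV block size

# Hypothetical stand-ins for the kernel's inputs (names are illustrative).
q       = rng.standard_normal(D).astype(np.float32)
k_mse   = rng.standard_normal((N, D)).astype(np.float32)  # MSE-compressed keys
signs   = rng.choice([-1.0, 1.0], size=(N, D)).astype(np.float32)
q_proj  = rng.standard_normal(D).astype(np.float32)       # QJL-projected query
r_norms = rng.random(N).astype(np.float32)                # ||r|| per key
v       = rng.standard_normal((N, D)).astype(np.float32)
scale, corr_scale = 1.0 / np.sqrt(D), 0.1

def block_scores(lo, hi):
    term1 = k_mse[lo:hi] @ q                 # ct.mma(Q, K_mse^T)
    term2 = signs[lo:hi] @ q_proj            # ct.mma(Q_proj, Signs^T)
    return term1 * scale + r_norms[lo:hi] * corr_scale * term2

# One-pass (online) softmax accumulation across KV blocks.
m, s, acc = -np.inf, 0.0, np.zeros(D, np.float32)
for lo in range(0, N, B):
    sc = block_scores(lo, lo + B)
    m_new = max(m, sc.max())
    alpha = np.exp(m - m_new)                # rescale the running state
    p = np.exp(sc - m_new)
    s = s * alpha + p.sum()
    acc = acc * alpha + p @ v[lo:lo + B]     # ct.mma(P, V)
    m = m_new
out_online = acc / s

# Reference: materialize all scores, softmax in one shot.
sc_all = block_scores(0, N)
p_all = np.exp(sc_all - sc_all.max())
out_ref = (p_all / p_all.sum()) @ v

assert np.allclose(out_online, out_ref, atol=1e-4)
```

The check confirms the key property the kernel relies on: processing KV blocks one at a time with running max/sum rescaling produces the same output as a full softmax, with no second pass over the scores.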
TMA pulls compressed data into shared memory; from there it flows through TMEM to the Tensor Cores in a single hop (a Blackwell-specific path), and the output tile is written at the end. One kernel, one pass, no round-trips to HBM.
Kernel signature
@ct.kernel(occupancy=2)
def turboquant_fused_attention_vfused_3bit(
    Q, K_mse, Signs, R_norms, Q_proj,
    V_Indices, V_Norms, Pi, Output,
    scale: float,
    correction_scale: float,
    seq_k: int,
    # ...8 centroid values unpacked as constants...
    USE_SWIZZLE: ConstBool,
):
    q_block = ct.bid(0)
    # Load Q and its QJL projection once per thread block
    q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM))
    qp_tile = ct.load(Q_proj, index=(q_block, 0), ...)
    # Load rotation matrix Pi into shared memory (resident across all KV blocks)
    pi_tile = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
    for kv_block in range(num_kv_blocks):
        # Score: Q @ K_mse^T
        k_tile = ct.load(K_mse, index=(kv_block, 0), ...)
        term1 = ct.mma(q_tile, k_tile, transpose_b=True)
        # QJL correction: Q_proj @ Signs^T
        s_tile = ct.load(Signs, index=(kv_block, 0), ...)
        term2 = ct.mma(qp_tile, s_tile, transpose_b=True)
        # Combine + online softmax
        norms = ct.load(R_norms, index=(kv_block, 0), ...)
        scores = term1 * scale + norms * correction_scale * term2
        m_new = ct.max(m_prev, ct.max(scores))
        p_tile = ct.exp2((scores - m_new) * INV_LOG_2)
        # v_tile is decompressed on-chip from V_Indices / V_Norms (next section)
        acc = acc * ct.exp2((m_prev - m_new) * INV_LOG_2) + ct.mma(p_tile, v_tile)
        ...
V-Fused Decompression
Rather than decompressing values in a separate kernel (writing FP16 to HBM, then reading it back), the fused attention kernel decompresses values on-chip inside the KV block loop.
The rotation matrix Π (128×128, 32 KB in FP16) stays resident in shared memory across all KV blocks. For each block, the kernel loads raw uint8 indices and FP16 norms, does the centroid lookup, un-rotates via ct.mma with Π, scales by norms, and feeds the result straight into the P × V accumulation.
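A NumPy sketch of that per-block path, assuming scalar centroids, an orthogonal Π, and illustrative shapes (the real kernel works on packed uint8 tiles in shared memory):

```python
import numpy as np

rng = np.random.default_rng(1)
D, B, C = 128, 64, 8          # head dim, KV block size, centroids (3-bit -> 8)

# Hypothetical inputs mirroring the kernel arguments (names are illustrative).
centroids = np.linspace(-1.5, 1.5, C).astype(np.float32)   # 8 scalar centroids
v_indices = rng.integers(0, C, size=(B, D)).astype(np.uint8)
v_norms   = rng.random(B).astype(np.float32)
pi        = np.linalg.qr(rng.standard_normal((D, D)))[0].astype(np.float32)

# On-chip decompression, as the fused kernel would do per KV block:
quantized = centroids[v_indices]          # centroid lookup from uint8 indices
# Un-rotate: the inverse of an orthogonal rotation is its transpose; whether
# pi or pi.T applies here depends on how compression applied the rotation.
unrotated = quantized @ pi.T              # the ct.mma with resident Pi
v_block   = unrotated * v_norms[:, None]  # scale each row by its FP16 norm
```

The un-rotation is itself just a matmul, which is why it maps onto the same Tensor Core path as the attention products and why keeping Π resident in shared memory costs only one 32 KB load per thread block.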
Blackwell-Specific Optimizations
Block Swizzling
Thread block IDs are interleaved with a stride of 4 to distribute L2 cache pressure. Adjacent blocks hit different cache lines instead of thrashing the same partition. At 16K tokens, swizzle cuts fused attention latency from 954 μs to 899 μs.
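The exact mapping is not shown in the source; a minimal sketch of one plausible stride-4 interleaving illustrates the idea:

```python
# Illustrative stride-4 swizzle: remap linear block IDs so that consecutively
# scheduled blocks land far apart, touching different L2 partitions.
# (The exact mapping used by the kernel is an assumption.)
STRIDE = 4

def swizzle(bid: int, num_blocks: int) -> int:
    groups = num_blocks // STRIDE
    return (bid % STRIDE) * groups + (bid // STRIDE)

order = [swizzle(b, 16) for b in range(16)]
# order == [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
```

The remap is a permutation of block IDs, so every tile is still processed exactly once; only the launch-order locality changes.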
Pipelined TMA Loads
latency=2 hints tell the hardware to prefetch the next KV block while the current one is being processed, overlapping memory transfer with Tensor Core compute.
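The scheduling pattern can be illustrated with a sequential Python analogy; on the GPU the load is asynchronous, which is what actually creates the overlap:

```python
# Double-buffering analogy for latency=2: issue the load for block i+1
# before computing on block i. Sequential here, concurrent on hardware.
def pipelined(blocks, load, compute):
    if not blocks:
        return []
    out = []
    nxt = load(blocks[0])              # prime the pipeline
    for i in range(len(blocks)):
        cur = nxt
        if i + 1 < len(blocks):
            nxt = load(blocks[i + 1])  # prefetch next block (async via TMA)
        out.append(compute(cur))
    return out

result = pipelined([1, 2, 3], lambda b: b * 10, lambda x: x + 1)
# result == [11, 21, 31]
```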
exp2 with Flush-to-Zero
ct.exp2(flush_to_zero=True) uses Blackwell's native base-2 exponential
fast path. Softmax is reformulated in base-2 to take advantage.
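The reformulation rests on the identity exp(z) = 2^(z · log2 e), which is why the kernel sketch multiplies by INV_LOG_2 before calling exp2. A quick NumPy check of the equivalence:

```python
import numpy as np

LOG2E = 1.4426950408889634   # log2(e) = 1/ln(2); the INV_LOG_2 constant

x = np.array([0.5, -1.2, 3.0, 0.0], dtype=np.float32)

# Standard base-e softmax
ref = np.exp(x - x.max())
ref /= ref.sum()

# Base-2 reformulation, mapping onto the hardware exp2 fast path
p = np.exp2((x - x.max()) * LOG2E)
p /= p.sum()

assert np.allclose(p, ref, atol=1e-6)
```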
Approximate Division
rounding_mode=RMd.APPROX trades a tiny precision delta for faster
reciprocal operations in the softmax normalization.
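The idea is to replace per-element division in the softmax normalization with one (approximate) reciprocal and a multiply. A NumPy analogy of why the precision delta is small; the hardware's approximate rounding mode itself is not modeled here:

```python
import numpy as np

# Softmax denominator and unnormalized probabilities (illustrative values)
s = np.float32(7.0)
p = np.arange(1, 5, dtype=np.float32)

out_div = p / s                    # exact per-element division
recip   = np.float32(1.0) / s      # stand-in for the approx reciprocal
out_mul = p * recip                # reciprocal-multiply fast path

# The two normalizations agree to within a few float32 ulps.
assert np.allclose(out_div, out_mul, rtol=1e-6)
```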
Occupancy Tuning
@ct.kernel(occupancy=2) targets 2 concurrent thread blocks per SM,
balancing register pressure against parallelism for latency hiding.
Online Softmax
Running max and exponential sum are maintained across KV blocks. No separate softmax pass needed. Numerically stable by construction.
Why cuTile?
cuTile is NVIDIA's Python DSL for tile-level GPU programming. The compiler handles Tensor Core scheduling, TMA dispatch, register allocation, and shared memory layout. You write Python. It compiles to optimized GPU code. No raw CUDA, no manual pointer arithmetic.
This is particularly powerful on Blackwell, where the TMEM → Tensor Core path and hardware TMA are first-class features that cuTile exposes directly.