perf(gemma4-cuda): MMQ + flash-attention prefill + dp4a decode (#141/#142) #143
…ec (#142)
Close the bulk of the Gemma 4 GPU gap vs llama.cpp on the RTX 4070 Ti (gemma-4-E4B-it-Q8_0):
  prefill ~1K ctx: 111 -> 1575 t/s (14.2x; llama.cpp ~8000)
  decode ~1K ctx: 43 -> 50.6 t/s (dp4a + CUDA graphs; llama.cpp 77.7)
  decode low ctx: 27 -> 58 t/s
#141 prefill: route the batched-trunk matmuls through a compute-bound cuBLAS GEMM instead of the memory-bound matvec GEMM-N (which re-streamed each weight per token):
- llm_dequant_q8_0_to_f16 + llm_f32_to_f16 kernels feed cublasGemmEx (fp16xfp16 -> fp32, fp32 accum); each Q8_0 weight is dequantized once per batch. CudaBackend.MatMulBatchedGemm + fp16 scratch (grow-only).
- F32 PLE projections go through Sgemm's TF32 path (was matvec GEMM-N, re-streaming a 110 MB weight per token -> 106ms; now ~6ms).
- PLE host dequant parallelized across tokens.
- PrefillGemmEnabled (SHARPI_PREFILL_GEMM) gates it; argmax-stable vs the fp32 sequential reference, not bit-exact (fp16 rounding).
#142 decode: profiling showed decode is matvec-bound (FFN+QKV+O = 71%), not attention-bound (6.4%), so the win is a tighter Q8_0 matvec, not flash attention:
- llm_matvec_q8_0_dp4a: quantize the activation to Q8_1 and use __dp4a int8 dot products (llama.cpp's MMVQ approach), replacing the per-element int8->float decode. Q80Dp4aEnabled (SHARPI_Q80_DP4A), default on.
- EnsureQ81Scratch pre-grows the Q8_1 buffer before the first CUDA-graph capture (capture forbids cudaMalloc).
Tests: Gemma4 CUDA suite green (batched-prefill oracle moved to a tolerance model + a dp4a-off bit-exact case; CudaQ8_0/CudaMatMulBatched bit-exact oracles pin Q80Dp4aEnabled=false and a new dp4a tolerance test added).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
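The dp4a decode idea in the commit above can be illustrated on the CPU. The following is a minimal sketch, not the kernel from this PR: `BlockQ8`, `quantize_block`, `dp4a_ref`, and `dot_q8` are hypothetical names, the block keeps its scale as fp32 (the real Q8_0 block stores an fp16 scale in a 34-byte layout), and the Q8_1 extra sum term is omitted. It shows symmetric 8-bit quantization of an activation block and a dot product that accumulates four int8 products at a time before applying the two scales once per block.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical 32-element block: one scale plus 32 int8 quants.
// Assumption: fp32 scale for simplicity (the real Q8_0 scale is fp16).
struct BlockQ8 {
    float d;
    int8_t qs[32];
};

// Symmetric quantization: d = max|x| / 127, q = round(x / d).
BlockQ8 quantize_block(const float* x) {
    float amax = 0.f;
    for (int i = 0; i < 32; ++i) amax = std::fmax(amax, std::fabs(x[i]));
    BlockQ8 b{};
    b.d = amax / 127.f;
    const float inv = (b.d > 0.f) ? 1.f / b.d : 0.f;
    for (int i = 0; i < 32; ++i) b.qs[i] = (int8_t)std::lround(x[i] * inv);
    return b;
}

// CPU stand-in for the CUDA __dp4a intrinsic:
// four signed 8-bit products, summed into a 32-bit accumulator.
int dp4a_ref(int32_t a, int32_t b, int acc) {
    int8_t av[4], bv[4];
    std::memcpy(av, &a, 4);
    std::memcpy(bv, &b, 4);
    for (int i = 0; i < 4; ++i) acc += (int)av[i] * (int)bv[i];
    return acc;
}

// Dot product of two quantized rows: exact int32 accumulation inside a block,
// one fp32 multiply by both scales per block.
float dot_q8(const std::vector<BlockQ8>& w, const std::vector<BlockQ8>& x) {
    assert(w.size() == x.size());
    float sum = 0.f;
    for (std::size_t b = 0; b < w.size(); ++b) {
        int acc = 0;
        for (int i = 0; i < 32; i += 4) {
            int32_t wa, xa;
            std::memcpy(&wa, w[b].qs + i, 4);  // pack 4 int8 weights
            std::memcpy(&xa, x[b].qs + i, 4);  // pack 4 int8 activations
            acc = dp4a_ref(wa, xa, acc);
        }
        sum += w[b].d * x[b].d * (float)acc;
    }
    return sum;
}
```

The integer part of each block is exact; the rounding relative to an fp32 reference comes only from quantizing the activation, which is consistent with the tolerance-based (rather than bit-exact) tests described in the commit.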
…de (#142) After #137's launch fusions and the dp4a decode matvec, CUDA-graph capture/replay measures +9-10% on the all-GPU Gemma 4 decode at both low and ~1K context (46 -> 49.9 t/s and 53 -> 58.2 t/s); the short-context regression that kept #136 default-off is gone. Flip _useCudaGraph to default on (SHARPI_CUDA_GRAPH=0 reverts). Scoped to CudaForwardPass; the hybrid path stays env-gated. Bit-parity already proven by Gemma4CudaGraphParityTests. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…, decode 49→59) Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…-k 64/top-p 0.95, --no-thinking) Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… model) Gemma 4 E4B-it's chat template renders a <|channel>thought block but the model wasn't trained to fill it, so the default enable_thinking=true made it degenerate into tag-soup. Detect arch == "gemma4" and default thinking off (prints a one-line notice); --no-thinking still forces off for every other model and the effective state flows to the temp-0 warning, prompt formatting, and the MTP gate via s_noThinking. Non-Gemma models are unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Adds llm_mmq_q8_0: a shared-tiled m16n8k32 s8 mma matmul that multiplies Q8_0 weights by Q8_1-quantized activations directly on int8 tensor cores, reading each weight once as int8 with no fp16 HBM dequant temp (the cost that capped MatMulBatchedGemm). 64x64 output tile, 8 warps x 4 mma/K-block, per-block fp32 scale accumulation, no sum/bias term (Q8_0 symmetric, D4). Wired into the Gemma 4 batched-trunk prefill behind PrefillMmqEnabled / SHARPI_PREFILL_MMQ (default off during bring-up; takes precedence over the GEMM path when on). New CudaBackend.MatMulBatchedMmq + sharpi_uint_at misaligned-word helper. Parity validated in CudaMmqQ8_0Tests (int8 dot is exact; tracks the fp32 DotQ8_0 reference to the dp4a per-row-RMS tolerance). Profiling note: at realistic prompt lengths (N=1848) the prefill matmul is only ~24% of the time; the scalar O(n^2) attention is ~67% (909ms). The MMQ foundation lands here; matmul-kernel tuning + flash-attention prefill follow. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Register-prefetch double-buffering for llm_mmq_q8_0: the next K-tile's global loads are issued into registers while the current tile's mma's run, so global latency hides behind compute instead of stalling the per-K-block barrier (cp.async is unusable — the Q8_0 qs is only 2-byte aligned). Widened the tile to 64x128 (halves the nTok/BN weight re-read factor). Over a 1848-tok Gemma 4 prefill the matmul portion drops ~332ms (cuBLAS GEMM) -> ~316ms (MMQ), so MMQ now beats the dequant->fp16->cuBLAS path and drops its fp16 weight HBM temp. PrefillMmqEnabled now defaults on (SHARPI_PREFILL_MMQ=0 reverts), gated under PrefillGemmEnabled. The GEMM-path oracle pins MMQ off to stay isolated; a new Gemma4_E4B_BatchedPrefill_MmqMatchesSequential asserts MMQ argmax-stability. End-to-end prefill is ~unchanged (1389 t/s): the 16ms matmul win is ~1% of the attention-dominated total. Profiling at N=1848 shows the scalar O(n^2) llm_full_seq_attention is ~67% (909ms) of prefill — the next target. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replaces the scalar llm_full_seq_attention / llm_attention_swa_batched (one
256-thread block PER query, each re-reading its whole K/V range from global —
O(n^2); for SWA, adjacent queries' 512-wide windows overlap ~99% so K/V is
re-streamed up to ~512x) with llm_flash_attn_prefill_f32: a block handles a
tile of 8 queries of one head and streams K/V through shared-memory tiles with
an online softmax (FlashAttention-2 style — running max/sum + rescaled output
accumulator, no n^2 score buffer). Each key is read from global once per 8
queries instead of once per query. One warp per query; lane L owns head dims
{L,L+32,..}; GQA, causal + optional sliding window, per-layer head_dim (256
SWA / 512 global). fp32 KV. Argmax-stable (online softmax reassociates the
same sum), not bit-exact.
Profiling at N=1848 showed attention was the dominant prefill cost and
memory-bound on redundant K/V reads (~79x off compute peak), not compute-bound.
Result: attention ~929ms -> ~411ms (2.26x), Gemma 4 prefill ~1389 -> ~2180 t/s
(1.57x). Decode unchanged (flash is prefill-only).
CudaBackend.FlashAttentionPrefill (dynamic shared sized to a 48KB budget),
PrefillFlashAttnEnabled / SHARPI_PREFILL_FLASH (default on). Parity validated
in CudaFlashAttnTests (vs both scalar kernels: GQA, hd 256/512, windowing,
partial tiles) + e2e Gemma4_E4B_BatchedPrefill_FlashAttnMatchesSequential. The
matvec/GEMM/MMQ bit-exact oracles pin flash off (it isn't bit-exact). 19
focused tests green.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
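The online softmax that the commit above relies on can be sketched on the CPU. This is a simplified illustration under stated assumptions, not the kernel itself: it handles one query with scalar values per key, and `attend_two_pass` / `attend_online` are hypothetical names. It shows why a running max, a running sum, and a rescaled accumulator give the same result as the two-pass form up to floating-point reassociation, without materializing all scores first.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Two-pass reference: out = sum_j softmax(s)_j * v_j, with a global max.
float attend_two_pass(const std::vector<float>& s, const std::vector<float>& v) {
    assert(s.size() == v.size() && !s.empty());
    float m = s[0];
    for (float x : s) m = std::fmax(m, x);
    float denom = 0.f, num = 0.f;
    for (std::size_t j = 0; j < s.size(); ++j) {
        const float p = std::exp(s[j] - m);
        denom += p;
        num += p * v[j];
    }
    return num / denom;
}

// Streaming form: keys arrive in tiles. m is the running max, l the running
// sum of exponentials, acc the running weighted sum. When the max grows,
// l and acc are rescaled by exp(m_old - m_new).
float attend_online(const std::vector<float>& s, const std::vector<float>& v,
                    std::size_t tile) {
    assert(s.size() == v.size() && !s.empty() && tile > 0);
    float m = -INFINITY, l = 0.f, acc = 0.f;
    for (std::size_t t0 = 0; t0 < s.size(); t0 += tile) {
        const std::size_t t1 = std::min(t0 + tile, s.size());
        float tile_max = s[t0];
        for (std::size_t j = t0; j < t1; ++j) tile_max = std::fmax(tile_max, s[j]);
        const float m_new = std::fmax(m, tile_max);
        const float scale = std::exp(m - m_new);  // exp(-inf) == 0 on the first tile
        l *= scale;
        acc *= scale;
        for (std::size_t j = t0; j < t1; ++j) {
            const float p = std::exp(s[j] - m_new);
            l += p;
            acc += p * v[j];
        }
        m = m_new;
    }
    return acc / l;
}
```

Because the streaming form sums the same terms in a different order, it should agree with the reference within a small tolerance rather than bit-for-bit, which matches the "argmax-stable, not bit-exact" posture stated in the commit.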
FA_QT 8 -> 16 (16 warps / 512 threads per block): doubles the K/V reuse factor (each key read from global once per 16 queries). Attention ~411ms -> ~382ms at N=1848; Gemma 4 prefill ~2180 -> ~2364 t/s. Diminishing return confirms the residual attention cost is now the scalar warp-reduce QK/PV compute, not K/V traffic — tensor-core QK/PV is the next lever. Parity unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The residual flash-attention cost was the scalar fp32 QK warp-reduce. Stage K as fp16 (half2-packed) in shared and run the QK dot with fma.rn.f16x2 (2 multiply- accumulates/instruction, ~2x the fp32 FMA rate); V stays fp32 for the exact scalar PV. Switched the per-lane dim ownership to adjacent PAIRS (pi = lane+32*p, dims 2*pi/2*pi+1) so the shared half2 K loads stay coalesced. New sharpi_f32x2_to_f16x2 / sharpi_hfma2 / sharpi_f16x2_sum inline-PTX helpers (NVRTC has no cuda_fp16.h). fp16 K also halves the K shared footprint → larger key tiles, fewer barriers. K rounded to fp16 (scores tolerate it — argmax-stable, maxAbs ~1.4e-4 vs the fp32 scalar reference); V/PV/softmax stay fp32. CudaFlashAttnTests tolerance relaxed to 5e-3*rms for the fp16-K relative error. Attention ~382ms -> ~245ms (1.56x); Gemma 4 prefill ~2364 -> ~2847 t/s. Session total: prefill 1389 -> 2847 t/s (2.05x). 18 focused tests green. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The batched-trunk prefill embedded tokens with a per-token EmbedLookupQ8_0 + a device copy — 2*N host-driven launches (~62ms at N=1848, launch-overhead bound now that attention is fast). Collapse it to a single grid.x=N launch: new llm_embed_lookup_q8_0_batched reads token_ids[blockIdx.x] from a device buffer and writes row blockIdx.x of _bpHidden directly (per-token body identical → bit-exact). CudaBackend.EmbedLookupQ8_0Batched; the prefill uploads the ids once and uses it for the Q8_0 embed table (Gemma 4), keeping the per-token loop for other dtypes. Embed ~62ms -> ~17ms; Gemma 4 prefill ~2847 -> ~2920 t/s. Session total: prefill 1389 -> 2920 t/s (2.10x). 22 focused CUDA tests green. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…hed embed) Updates the Gemma 4 E4B CUDA row for the #141 prefill work (int8 tensor-core MMQ, flash-attention prefill, batched embed; ~1.8× at ~1K ctx, ~2.05× at 1.8K). Also fixes a stale kernel name (llm_mmq_q8_0_naive → llm_mmq_q8_0) in CudaMmqQ8_0Tests doc (pr-review-toolkit finding). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Code Review
This pull request introduces significant performance optimizations for Gemma 4 prefill and decode paths, including a memory-efficient flash-attention prefill kernel, an int8 tensor-core MMQ matmul, a dp4a-based Q8_0 decode matvec, and batched embedding lookups. The code review feedback suggests several high-impact optimizations: replacing runtime integer division and modulo operations with bitwise shifts and masks in the flash-attention kernel, consolidating dual 8-bit loads into single 16-bit loads for scale factors across multiple CUDA kernels, fixing a late overflow check in MatMulBatchedMmq, and utilizing CollectionsMarshal.AsSpan to avoid allocations when processing token lists.
for (int idx = tid; idx < kt_tile * hd2; idx += (int)blockDim.x) {
    int kk = idx / hd2, pr = idx - kk * hd2;
    unsigned int kh = 0u;
    if (kk < tile_keys) {
        long off = (long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + 2 * pr;
        kh = sharpi_f32x2_to_f16x2(k_cache[off], k_cache[off + 1]);
    }
    sKh[idx] = kh;
}
Since head_dim (and thus hd2) is always a power of two in LLMs (e.g., 256, 512), we can completely avoid the expensive runtime integer division and modulo operations by using bitwise shifts and masks. This significantly improves the performance of the key loading phase in the flash attention prefill kernel.
int shift_k = __ffs(hd2) - 1;
int mask_k = hd2 - 1;
for (int idx = tid; idx < kt_tile * hd2; idx += (int)blockDim.x) {
int kk = idx >> shift_k;
int pr = idx & mask_k;
unsigned int kh = 0u;
if (kk < tile_keys) {
long off = (long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + 2 * pr;
kh = sharpi_f32x2_to_f16x2(k_cache[off], k_cache[off + 1]);
}
sKh[idx] = kh;
}

for (int idx = tid; idx < kt_tile * head_dim; idx += (int)blockDim.x) {
    int kk = idx / head_dim, d = idx - kk * head_dim;
    sV[idx] = (kk < tile_keys)
        ? v_cache[(long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + d]
        : 0.f;
}
Similarly to the key loading phase, head_dim is always a power of two, so we can replace the runtime integer division and modulo with bitwise operations to optimize the value loading phase.
int shift_v = __ffs(head_dim) - 1;
int mask_v = head_dim - 1;
for (int idx = tid; idx < kt_tile * head_dim; idx += (int)blockDim.x) {
int kk = idx >> shift_v;
int d = idx & mask_v;
sV[idx] = (kk < tile_keys)
? v_cache[(long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + d]
: 0.f;
}

unsigned int dlo = sharpi_byte_at(weights, b0);
unsigned int dhi = sharpi_byte_at(weights, b0 + 1);
float dw = sharpi_fp16_to_fp32(dlo | (dhi << 8));
Loading the 16-bit fp16 scale d using two separate sharpi_byte_at calls is inefficient. Since b0 is always 2-byte aligned, we can perform a single 16-bit load by casting the pointer to const unsigned short*, which compiles to a single ld.global.u16 instruction.
unsigned int d_bits = *reinterpret_cast<const unsigned short*>(reinterpret_cast<const char*>(weights) + b0);
float dw = sharpi_fp16_to_fp32(d_bits);

if (tid < MMQ_BM) { long wb = ((long)(row_block + tid) * nb + (KB)) * 34L; \
rWd = (row_block + tid < rows) ? sharpi_fp16_to_fp32(sharpi_byte_at(weights, wb) | (sharpi_byte_at(weights, wb + 1) << 8)) : 0.f; } \
We can optimize the scale loading inside the MMQ_LOAD_TILE macro by performing a single 16-bit load instead of two sharpi_byte_at calls.
if (tid < MMQ_BM) { long wb = ((long)(row_block + tid) * nb + (KB)) * 34L; \
rWd = (row_block + tid < rows) ? sharpi_fp16_to_fp32(*reinterpret_cast<const unsigned short*>(reinterpret_cast<const char*>(weights) + wb)) : 0.f; } \

unsigned int dlo = sharpi_byte_at(weights, b0 + 0);
unsigned int dhi = sharpi_byte_at(weights, b0 + 1);
float d = sharpi_fp16_to_fp32(dlo | (dhi << 8));
int qN = (int)((long)cols * nTok);
if ((long)cols * nTok > int.MaxValue)
    throw new InvalidOperationException(
        $"MatMulBatchedMmq: cols*nTok ({(long)cols * nTok}) exceeds int range.");
The overflow check (long)cols * nTok > int.MaxValue is performed after casting the value to int on line 1824. If the value exceeds int.MaxValue, qN will already contain an overflowed/truncated value before the check runs. Performing the check on the long value first is safer and cleaner.
long totalElements = (long)cols * nTok;
if (totalElements > int.MaxValue)
throw new InvalidOperationException(
$"MatMulBatchedMmq: cols*nTok ({totalElements}) exceeds int range.");
int qN = (int)totalElements;

int[] ids = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);
var idTensor = _gpu.UploadRaw(
    System.Runtime.InteropServices.MemoryMarshal.AsBytes<int>(ids),
    TensorShape.D1(N), DType.Float32);
To avoid array allocations when tokens is a List<int>, we can use CollectionsMarshal.AsSpan to get a direct span of the list's underlying array. This completely avoids allocations for both int[] and List<int> collections.
ReadOnlySpan<int> idsSpan = tokens is int[] arr ? arr :
tokens is System.Collections.Generic.List<int> list ? System.Runtime.InteropServices.CollectionsMarshal.AsSpan(list) :
System.Linq.Enumerable.ToArray(tokens);
var idTensor = _gpu.UploadRaw(
System.Runtime.InteropServices.MemoryMarshal.AsBytes<int>(idsSpan),
TensorShape.D1(N), DType.Float32);
Reviewed the Gemini code-assist suggestions — all 7 are legitimate medium-priority micro-opts. Deferred to #145 rather than amending a CI-green, tested, review-toolkit-clean PR at merge time (the two shift/mask suggestions also bake a power-of-two |
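The power-of-two assumption behind the shift/mask suggestions can be made explicit with a small CPU sketch. The names here (`is_pow2`, `log2_pow2`, `split_divmod`, `split_shift_mask`) are hypothetical and do not appear in the PR; `log2_pow2` stands in for the `__ffs(n) - 1` expression in the suggestion. The point is that the shift/mask form equals the divide/modulo form only when the width is a power of two, so a guard is needed if that is not guaranteed.

```cpp
#include <cassert>

// True when n is a positive power of two.
bool is_pow2(int n) { return n > 0 && (n & (n - 1)) == 0; }

// log2 of a power of two (CPU stand-in for __ffs(n) - 1).
int log2_pow2(int n) {
    assert(is_pow2(n));
    int s = 0;
    while ((1 << s) < n) ++s;
    return s;
}

struct RowCol {
    int row;
    int col;
};

// General form: valid for any positive width.
RowCol split_divmod(int idx, int width) {
    return {idx / width, idx % width};
}

// Fast form: valid ONLY when width is a power of two and idx >= 0.
RowCol split_shift_mask(int idx, int width) {
    assert(is_pow2(width) && idx >= 0);
    return {idx >> log2_pow2(width), idx & (width - 1)};
}
```

For a width such as 256 or 512 both forms agree; for a non-power-of-two width the mask no longer computes the remainder, which is why the fast form asserts its precondition here.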
Summary
Closes the Gemma 4 E4B CUDA performance gap to llama.cpp on prefill (issues #141 / #142), on an RTX 4070 Ti with gemma-4-E4B-it-Q8_0.
The two foundational commits (#141 cuBLAS-GEMM prefill + #142 dp4a/graphs decode) landed first (109→1564 prefill, 49→59 decode). The rest of this branch then closed most of the remaining prefill gap after profiling revealed the real bottleneck.
The key finding
The issue framed prefill as matmul-bound, but profiling at a realistic prompt length (N=1848) showed attention was 67% of prefill — it's O(n²) and the old "~30%" note was a short-prompt artifact — and memory-bound on redundant K/V reads (the scalar per-query kernel re-streamed each query's K/V window up to ~512× for SWA layers; ~79× off compute peak). So the lever was attention, not the matmul.
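The redundant-read argument can be checked with a simple counting model. The sketch below is an illustration under assumptions, not a measurement: it assumes a causal sliding window in which query i attends keys max(0, i - window + 1) through i, counts one global read per key per block, and ignores caching. `reads_per_query` and `reads_tiled` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>

// One block per query: every query re-reads its whole key range from global.
long long reads_per_query(int n, int window) {
    assert(n > 0 && window > 0);
    long long total = 0;
    for (int i = 0; i < n; ++i) total += std::min(i + 1, window);
    return total;
}

// One block per tile of qt queries: the union of the tile's key ranges is
// streamed through shared memory once and reused by all queries in the tile.
long long reads_tiled(int n, int window, int qt) {
    assert(n > 0 && window > 0 && qt > 0);
    long long total = 0;
    for (int q0 = 0; q0 < n; q0 += qt) {
        const int q1 = std::min(q0 + qt, n) - 1;       // last query in the tile
        const int lo = std::max(0, q0 - window + 1);   // first key any query needs
        total += (long long)(q1 - lo + 1);             // keys lo..q1, read once
    }
    return total;
}
```

Calling both functions with n = 1848, window = 512, and qt = 8 or 16 gives a rough feel for how much global K/V traffic tiling removes in this model; with window >= n the per-query count reduces to n(n+1)/2, the full-causal O(n²) case.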
What's in the branch
Prefill (#141):
- llm_mmq_q8_0: inline-PTX mma.sync.m16n8k32.s8, Q8_0 weights × Q8_1 activations, register-prefetch double-buffered. Reads each weight once as int8 with no fp16 HBM dequant temp; beats cuBLAS GEMM (matmul 332→316ms). Default-on (SHARPI_PREFILL_MMQ=0 reverts).
- llm_flash_attn_prefill_f32: shared K/V tiles reused across a 16-query tile + online softmax (FA2 style), half2 fp16x2 QK dot via inline PTX (K staged fp16, V fp32 for exact PV). Attention 929→245ms (3.8×). Default-on (SHARPI_PREFILL_FLASH=0 reverts).
- llm_embed_lookup_q8_0_batched: collapses 2·N host launches into one (embed 62→17ms).
Decode (#142): dp4a/Q8_1 int8 matvec + CUDA graphs default-on.
Net: Gemma 4 prefill ~1389 → 2920 t/s (2.10×) this session; 109 → 2853 across the whole branch. Decode untouched this session.
Correctness
MMQ and flash are argmax-stable, not bit-exact (int8 / fp16 rounding) — by design, consistent with the existing cuBLAS-GEMM prefill posture. The bit-exact parity oracles deliberately pin them off; new tolerance oracles cover the fast paths:
- CudaMmqQ8_0Tests (int8 dot vs fp32 DotQ8_0)
- CudaFlashAttnTests (vs both scalar attention kernels: GQA, head_dim 256/512, windowing, partial tiles)
- Gemma4_E4B_BatchedPrefill_{Mmq,FlashAttn}MatchesSequential (e2e argmax-stability)
22 focused CUDA tests green; the full CUDA suite (120) was green before the final additive (batched-embed) commit. Reviewed with pr-review-toolkit (code-reviewer + silent-failure-hunter): no critical/important findings.
Remaining gap
Prefill is now well-balanced (matmul 45% / attention 36% / embed+ple 19%); closing the rest needs a full tensor-core flash at d=512 (occupancy-limited even in llama.cpp) and decode-side work (dp4a qs-alignment).
🤖 Generated with Claude Code