perf(gemma4-cuda): TC flash-attention prefill (#146/#147) + SoA Q8_0 weight layout (#149), prefill 2853→4240 #148
…ve (#146) Single-warp NVRTC test kernel + host wrapper + unit test proving the mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 fragment layouts (A 16x16 row-major, B 16x8 col-major, C 16x8 fp32) against a CPU fp16-rounded reference. maxAbs 4.77e-7, 0/128 mismatches — the A/B/C lane->register maps are correct. This de-risks the hardest unknown for the full TC flash-attention prefill: a wrong fragment map silently produces garbage. Reusable building block for the QK^T and P.V mma stages. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
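The lane-to-register maps this test validates can be sketched on the CPU. The sketch below follows the m16n8k16 fragment layouts for fp16 inputs and fp32 accumulators as documented in the PTX ISA (groupID = lane / 4, tig = lane % 4). It is not the PR's kernel, and all names (`fragA`, `fragB`, `fragC`, `coversOnce`, `scoreFragmentMatchesA`) are hypothetical. It checks that each map touches every tile element exactly once, and that a C fragment lines up with half of an A fragment, which is the property the flash kernel relies on to reuse scores without a transpose.

```cpp
#include <array>
#include <cassert>

// (row, col) of one fragment element inside its tile.
struct RC { int row, col; };

// A: 16x16 row-major, 8 fp16 values per lane (i = 0..7).
inline RC fragA(int lane, int i) {
    int g = lane >> 2, tig = lane & 3;
    int row = (i & 2) ? g + 8 : g;                 // i in {2,3,6,7} -> rows 8..15
    int col = 2 * tig + (i & 1) + ((i >= 4) ? 8 : 0);
    return {row, col};
}

// B: 16x8 (k x n) col-major, 4 fp16 values per lane (i = 0..3).
inline RC fragB(int lane, int i) {
    int g = lane >> 2, tig = lane & 3;
    return {2 * tig + (i & 1) + ((i >= 2) ? 8 : 0), g};
}

// C: 16x8 fp32 accumulator, 4 values per lane (i = 0..3).
inline RC fragC(int lane, int i) {
    int g = lane >> 2, tig = lane & 3;
    return {(i >= 2) ? g + 8 : g, 2 * tig + (i & 1)};
}

// True when the map hits every element of a ROWS x COLS tile exactly once.
template <int ROWS, int COLS, int REGS>
bool coversOnce(RC (*map)(int, int)) {
    std::array<int, ROWS * COLS> hits{};
    for (int lane = 0; lane < 32; ++lane)
        for (int i = 0; i < REGS; ++i) {
            RC rc = map(lane, i);
            if (rc.row < 0 || rc.row >= ROWS || rc.col < 0 || rc.col >= COLS) return false;
            ++hits[rc.row * COLS + rc.col];
        }
    for (int h : hits)
        if (h != 1) return false;
    return true;
}

inline bool allFragmentsCoverOnce() {
    return coversOnce<16, 16, 8>(fragA) && coversOnce<16, 8, 4>(fragB) &&
           coversOnce<16, 8, 4>(fragC);
}

// Register i of score tile j (keys 8j..8j+7) sits where A register i + 4j lives.
inline bool scoreFragmentMatchesA() {
    for (int lane = 0; lane < 32; ++lane)
        for (int j = 0; j < 2; ++j)
            for (int i = 0; i < 4; ++i) {
                RC c = fragC(lane, i);
                RC a = fragA(lane, i + 4 * j);
                if (c.row != a.row || c.col + 8 * j != a.col) return false;
            }
    return true;
}
```

A wrong constant in any of the three maps makes `allFragmentsCoverOnce()` fail, which is a cheap host-side complement to the on-device comparison against the fp16-rounded reference.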
Full TC version of the half2 flash kernel: both QK^T and P.V run on the mma cores (mma.sync m16n8k16, fp16 multiplicands / fp32 accumulate), one warp per 16-query tile, online softmax.

Key design points:
- The QK^T score C-fragment is reused directly as the P.V A-fragment with no transpose: the key index is QK^T's N-col and P.V's contraction dim, and the m16n8k16 C and A fragment layouts coincide on (row, 2*tig).
- O[16 x head_dim] is too large for registers (256 regs/lane at d=512), so it lives in shared fp32 and is rescaled in place per key-tile; K and V time-share one 16 x head_dim fp16 buffer, fitting the whole kernel in 48 KB shared at d=512.
- Guards the masked-tile online-softmax NaN (a later query's early tiles fall entirely outside its sliding window -> mnew=-inf -> exp(-inf+inf)).

Parity (CudaFlashAttnTcTests): maxAbs ~3-4e-4 vs the scalar batched kernels across GQA, both Gemma 4 head_dims (256 SWA / 512 global), windowing, partial tiles; more accurate than the half2 kernel. End-to-end FlashTcMatchesSequential green.

A/B vs the half2 baseline (all-GPU, warm, gemma-4-E4B Q8_0, 4070 Ti):
  N=965:  2585 -> 2732 t/s (+5.7%)
  N=2054: 2716 -> 2849 t/s (+4.9%)

The win is occupancy-limited (single warp/block + 48 KB shared -> ~2 warps/SM); a multi-warp / d-split version is the path to a larger gain. Opt-in for now (SHARPI_PREFILL_FLASH_TC=1, gated on head_dim % 16 == 0) until that lands.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
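The masked-tile NaN guard can be illustrated with a scalar CPU sketch of online softmax. This is a simplified stand-in, not the kernel: it handles one query row with scalar values, and the function name and the "skip the tile" form of the guard are assumptions. Masked scores are -infinity; if a whole tile is masked before any valid key has been seen, the running max stays -infinity and the rescale factor exp(m - mNew) would be exp(-inf + inf) = NaN.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Softmax-weighted sum of `values`, streamed over key tiles of size `tile`.
inline float onlineSoftmaxWeightedSum(const std::vector<float>& scores,
                                      const std::vector<float>& values,
                                      std::size_t tile) {
    assert(scores.size() == values.size() && tile > 0);
    const float NEG_INF = -std::numeric_limits<float>::infinity();
    float m = NEG_INF;  // running max of the scores seen so far
    float l = 0.0f;     // running sum of exp(score - m)
    float o = 0.0f;     // running unnormalised output
    for (std::size_t t0 = 0; t0 < scores.size(); t0 += tile) {
        const std::size_t t1 = std::min(t0 + tile, scores.size());
        float mNew = m;
        for (std::size_t k = t0; k < t1; ++k) mNew = std::max(mNew, scores[k]);
        // Guard: no valid key yet, so there is nothing to rescale or accumulate.
        // Without it, exp(m - mNew) below evaluates exp(-inf + inf) = NaN.
        if (mNew == NEG_INF) continue;
        const float scale = std::exp(m - mNew);  // 0 when m is still -inf
        l *= scale;
        o *= scale;
        for (std::size_t k = t0; k < t1; ++k) {
            const float p = std::exp(scores[k] - mNew);  // 0 for masked keys
            l += p;
            o += p * values[k];
        }
        m = mNew;
    }
    return l > 0.0f ? o / l : 0.0f;
}
```

With the first tile fully masked and two equal valid scores in the second, the result is the plain average of the two valid values; removing the guard turns it into NaN.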
Code Review
This pull request introduces a tensor-core flash-attention prefill implementation (Issue #146) utilizing mma.sync.aligned.m16n8k16 instructions. It adds the CUDA kernels llm_mma_test_m16n8k16_f32 and llm_flash_attn_prefill_tc, exposes them via CudaBackend, integrates the tensor-core prefill option into CudaForwardPass (controlled by the SHARPI_PREFILL_FLASH_TC environment variable), and includes benchmarking scripts and comprehensive unit tests to validate correctness against scalar references. No review comments were provided, so there is no additional feedback to address.
…147)

Fixes the single-warp #146 occupancy limit. A block is now W=4 warps cooperating on one 16-query tile with the head dim split across them: warp w owns output columns [w*dW, ...) (dW = head_dim/W), so O[16 x dW] is REGISTER-resident (64 regs/lane at d=512) instead of in shared. There is no per-key-tile shared-O rescale, and the freed shared lifts occupancy ~2 -> ~16-20 warps/SM.

Each warp computes a PARTIAL QK^T over its d-slice; the partials sum across warps through a small shared S buffer ([W x 16 x 16] fp32), after which every warp holds the full reduced score tile and proceeds like the single-warp kernel (no-transpose score->P, P.V for its slice). Requires head_dim % 64 == 0 (W*16); the #146 single-warp kernel is the head_dim % 16 fallback.

Parity (CudaFlashAttnTcTests.TC2): maxAbs identical to the single-warp kernel and the scalar reference (~3-4e-4) across GQA / 256-SWA & 512-global / windowing / partial tiles: the d-split reduction is numerically equivalent.

A/B vs half2 (all-GPU, warm, gemma-4-E4B Q8_0, 4070 Ti):
  N=965:  2582 (half2) -> 2691 (tc1) -> 3269 (tc2) = +26.6%
  N=2054: 2732 (half2) -> 2851 (tc1) -> 3818 (tc2) = +39.8% (win grows with ctx)

Now default on (SHARPI_PREFILL_FLASH_TC=0 reverts to half2; =1 + _TC1=1 forces the single-warp kernel for A/B). Pinned off in the bit-exact / half2 Gemma4 oracles. All 23 focused Gemma4Cuda/TC/MMQ/primitive tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
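The register arithmetic and the d-split score reduction described above can be checked with a small CPU sketch. It is an illustration under stated assumptions, not the kernel: `regsPerLane` and `dSplitScore` are hypothetical names, and the `partial` vector only plays the role of the shared S buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// fp32 accumulator registers per lane needed to keep a 16 x cols output tile
// resident in one 32-lane warp.
constexpr int regsPerLane(int cols) { return 16 * cols / 32; }

// One (query, key) score computed as W partial dot products over disjoint
// head-dim slices, then summed across the W "warps".
inline float dSplitScore(const std::vector<float>& q, const std::vector<float>& k, int W) {
    assert(W > 0 && q.size() == k.size() && q.size() % static_cast<std::size_t>(W) == 0);
    const std::size_t dW = q.size() / static_cast<std::size_t>(W);  // slice width per warp
    std::vector<float> partial(static_cast<std::size_t>(W), 0.0f);
    for (std::size_t w = 0; w < partial.size(); ++w)
        for (std::size_t j = w * dW; j < (w + 1) * dW; ++j)
            partial[w] += q[j] * k[j];       // warp w only touches its d-slice
    float score = 0.0f;
    for (float p : partial) score += p;      // cross-warp reduction
    return score;
}
```

At d=512 a single warp would need `regsPerLane(512)` = 256 registers per lane for O, while a W=4 split needs `regsPerLane(128)` = 64, matching the figures in the commit message.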
Update: multi-warp / d-split (#147) landed: default-on, +27-40%
The single-warp kernel above was occupancy-limited (+5%). The follow-up …
Parity identical to scalar (~3-4e-4). Now default-on ( …
After the TC flash work moved the prefill bottleneck to the matmul/FFN GEMMs (profiling: ~53% of prefill), this probe times MatMulBatchedMmq at FFN-shaped GEMMs and reports achieved int8 TOPS vs the ~160 TOPS dense peak:
  ffn-gate/up [8192x2048]: 53.7 TOPS (34%)
  ffn-down    [2048x8192]: 36.4 TOPS (23%) <- worst, occupancy-starved
  qkv         [6144x2048]: 52.2 TOPS (33%)

MMQ is at 23-34% of TC peak (~3x headroom), matching the 2.2x end-to-end gap to llama.cpp. ffn-down (few output rows over a long contraction) is occupancy-bound the same way the single-warp flash kernel was; split-K + better tiling is the next attack. Diagnostic only (asserts true; run via --filter).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
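The percent-of-peak figures follow from simple arithmetic, sketched below. The counting convention (two integer ops per multiply-accumulate, so 2·M·N·K per GEMM) is the usual one but is an assumption about how the probe counts; the function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Achieved tera-ops/s for an M x N x K GEMM that took `seconds`,
// counting one multiply and one add per (m, n, k) triple.
inline double achievedTops(double m, double n, double k, double seconds) {
    return 2.0 * m * n * k / seconds / 1e12;
}

// Achieved throughput as a whole-number percentage of the peak.
inline int percentOfPeak(double tops, double peakTops) {
    return static_cast<int>(std::round(100.0 * tops / peakTops));
}
```

For example, `percentOfPeak(53.7, 160.0)` gives 34 and `percentOfPeak(36.4, 160.0)` gives 23, the values reported for ffn-gate/up and ffn-down.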
…149, isolated)

The interleaved Q8_0 block (34 B = 2 B fp16 scale + 32 int8) puts the qs at a 2-byte offset, so every weight word in the MMQ load costs an extra global load + __funnelshift (sharpi_uint_at). llm_mmq_q8_0_soa reads the weights from a struct-of-arrays repack instead (quants 16 B-aligned, 8 uint/block; fp16 scales separate), so each weight word is a plain aligned load. Only the weight path changes; activations (Q8_1) are already aligned.

Validated in isolation (CudaMmqSoaTests):
- BIT-IDENTICAL to the interleaved MMQ (maxAbs 0, 0 diffs) across 5 shapes incl. ffn-down [2048x8192] and qkv.
- Speed at REAL prefill nTok=2048 (not the 1024 that mismeasured split-K):
    ffn-gate/up 54.6 -> 61.4 TOPS (+11.1%)
    ffn-down    39.6 -> 49.1 TOPS (+19.4%)
    qkv         54.6 -> 60.5 TOPS (+9.7%)

A per-instruction-efficiency win (not occupancy), so it should translate end-to-end. Next: model-wide SoA migration (upload repack + the decode matvec / embed / dequant Q8_0 readers, since each weight is shared by prefill MMQ and decode).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
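The cost of the 2-byte offset can be illustrated on the CPU. The sketch below emulates CUDA's `__funnelshift_r` and reads a little-endian 32-bit word at an arbitrary byte offset using only aligned 32-bit loads. `wordAt` is a hypothetical stand-in for what a helper like `sharpi_uint_at` has to do, not its actual implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Emulates __funnelshift_r(lo, hi, shift): the low 32 bits of (hi:lo) >> (shift & 31).
inline uint32_t funnelShiftRight(uint32_t lo, uint32_t hi, unsigned shift) {
    const uint64_t both = (static_cast<uint64_t>(hi) << 32) | lo;
    return static_cast<uint32_t>(both >> (shift & 31u));
}

// Byte offset of the first quant of interleaved block `b` (2 B scale, then 32 quants).
constexpr std::size_t quantByteOffset(std::size_t b) { return 34 * b + 2; }

// 32-bit little-endian word starting at `byteOffset`, built from aligned words.
inline uint32_t wordAt(const std::vector<uint32_t>& words, std::size_t byteOffset) {
    const std::size_t i = byteOffset / 4;
    const unsigned shift = static_cast<unsigned>(byteOffset % 4) * 8;
    if (shift == 0) return words[i];                          // aligned: one load
    return funnelShiftRight(words[i], words[i + 1], shift);   // misaligned: two loads + shift
}
```

In the SoA layout every quant word takes the first branch, which is the per-instruction saving the commit measures.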
…prefill (#149)

Repacks 2-D Q8_0 GEMM weights into the SoA layout at upload (mmqSoa flag / SHARPI_MMQ_SOA, default off), so the prefill MMQ and decode dp4a matvec read the weight with plain aligned loads instead of the 2-byte-misalignment funnelshift.
- llm_q8_0_repack_soa: one-time GPU repack interleaved [34B/block] → [quants rows*cols B][scales rows*nb fp16]. CudaBackend.RepackQ8_0Soa allocates the dest, repacks, frees the source, and marks the handle in _soaHandles.
- llm_mmq_q8_0_soa (refactored to the AoS signature, scales located internally) and llm_matvec_q8_0_dp4a_soa: aligned-load variants. MatMulBatchedMmq and the decode matvec auto-route to them per repacked handle, with no caller changes.
- UploadWeight repacks 2-D Q8_0 (norms/biases are 1-D; embedding uploads elsewhere); output.weight auto-routes via the handle set.

Verified BIT-IDENTICAL end-to-end (CudaMmqSoaE2ETests: Gemma 4 prefill + 4 decode steps, maxAbs 0) and at the kernel level (CudaMmqSoaTests, +13-18% isolated GEMM).

A/B end-to-end (all-GPU, warm, N=2054): prefill 3849 → 4240 t/s (+10.2%), decode neutral. 26 focused tests green (SoA off by default).

Opt-in until the remaining Q8_0 readers (dequant, fp32 matvec, gemm_n; used only by the bit-exact/gemm-off oracles) are converted, then it can default on. Follow-up.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
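A CPU sketch of the repack makes the target layout concrete. It assumes blocks are stored row-major and keep their order in both arrays, which the commit message does not spell out; `repackSoa` and the constants are hypothetical names, and the real repack is the GPU kernel `llm_q8_0_repack_soa`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr std::size_t kBlockBytes = 34;  // 2 B fp16 scale + 32 int8 quants
constexpr std::size_t kQuants = 32;      // quants per block

// Interleaved [scale | 32 quants] blocks -> [all quants][all scales].
// The rows*cols quant bytes come first; the fp16 scales start at byte rows*cols.
inline std::vector<uint8_t> repackSoa(const std::vector<uint8_t>& aos,
                                      std::size_t rows, std::size_t cols) {
    assert(cols % kQuants == 0);
    const std::size_t nb = cols / kQuants;  // blocks per row
    assert(aos.size() == rows * nb * kBlockBytes);
    std::vector<uint8_t> soa(rows * cols + rows * nb * 2);
    uint8_t* scales = soa.data() + rows * cols;
    for (std::size_t b = 0; b < rows * nb; ++b) {
        const uint8_t* src = aos.data() + b * kBlockBytes;
        std::memcpy(scales + b * 2, src, 2);                       // fp16 scale bits
        std::memcpy(soa.data() + b * kQuants, src + 2, kQuants);   // 32 quant bytes
    }
    return soa;
}
```

Because only bytes are moved, any reader that locates the scale at `rows*cols + 2*block` and the quants at `32*block` sees the same values as the interleaved reader, which is why bit-identical results are the expected outcome.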
Completes the SoA Q8_0 migration so it can default on. Added SoA variants of the three remaining interleaved readers (llm_matvec_q8_0 (fp32 decode), _gemm_n, and llm_dequant_q8_0_to_f16), each bit-identical to its AoS counterpart (aligned quant bytes + one aligned fp16 scale per block, scales located at byte rows*cols). All five Q8_0 readers (MMQ, dp4a, fp32 matvec, GEMM-N, dequant) now auto-route to the SoA kernel per repacked handle; the embedding stays interleaved (uploaded separately, not repacked).

mmqSoa now defaults ON (SHARPI_MMQ_SOA=0 reverts). With every reader SoA-aware the bit-exact Gemma4 oracles exercise the SoA path and pass bit-exactly (a half-migration would have let them falsely pass on garbage: they compare batched-vs-sequential, so both sides must read the format correctly). 26 focused Gemma4/Cuda tests green incl. GemmOff_MatchesSequentialBitExact.

End-to-end (all-GPU, warm, gemma-4-E4B Q8_0, 4070 Ti): prefill +10-12% now by default (N=965 3297→3698, N=2054 3849→4240). Prefill ~4240 t/s = ~50% of llama.cpp's ~8475.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…A Q8_0 #149 default-on Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…aHandles lifecycle

From the pr-review-toolkit cycle (code-reviewer, silent-failure-hunter, pr-test-analyzer, comment-analyzer). No critical/high defects were found (all 5 Q8_0 readers are SoA-routed; the one non-SoA reader, MatMulN2, throws on Q8_0 rather than reading garbage; NVRTC failures surface loudly). Addressed the actionable items:
- Test gap (sev 8): the GPU repack kernel + the fp32-matvec / GEMM-N / dequant SoA readers had no GGUF-free coverage (only bench-machine model oracles). Added GpuRepack_AllSoaReaders_BitIdenticalToInterleaved: it repacks on the GPU (production path) and asserts all five readers (dp4a, fp32, GEMM-N, dequant, MMQ) are bit-identical to interleaved across 3 shapes; runs on any CUDA box.
- _soaHandles lifecycle: switched HashSet→ConcurrentDictionary (matches the sibling handle tables, removes the latent thread-safety footgun the whole SoA-correctness invariant rested on) and pruned it on Free + Dispose (it was an unbounded leak across model load/free cycles).
- README: balanced a stray parenthesis in the Gemma 4 CUDA prefill cell.

28 focused Gemma4/Cuda tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
pr-review-toolkit cycle (4 agents): results
Ran code-reviewer, silent-failure-hunter, pr-test-analyzer, comment-analyzer over … No critical/high defects. Verified by the agents: all 5 Q8_0 readers are SoA-routed; the one non-SoA reader ( … Addressed (commit …
Noted, not addressed (lower-value test hardening, optional follow-ups):
These are hardening suggestions, not bugs; deferring.
Summary
Two stacked, default-on Gemma 4 CUDA prefill wins, each built bottom-up and validated. Prefill 2853 → 4240 t/s (1.8K ctx) = 34% → 50% of llama.cpp's ~8475. All fast paths argmax-stable; the SoA repack is bit-identical. Full ForwardPass suite 382/382 green.
1. Tensor-core flash-attention prefill (#146/#147)
Both QK^T and P·V on the mma cores (mma.sync.m16n8k16.f32.f16.f16.f32), replacing the scalar O(n²) per-query attention.
- … llm_mma_test_m16n8k16_f32, maxAbs 4.77e-7): a wrong lane→register map silently produces garbage.
- … (row, 2·tig)); O[16×512] lives in shared (256 regs/lane otherwise). +5%, occupancy-limited.
- … SHARPI_PREFILL_FLASH_TC=0 reverts).

2. SoA Q8_0 weight layout (#149)
Repacks 2-D Q8_0 weights at upload into [quants rows*cols B][scales rows*nb fp16] so all readers load the quants 16-byte aligned instead of the interleaved 34-byte block's 2-byte-misalignment __funnelshift.
- … llm_q8_0_repack_soa); all five Q8_0 readers (MMQ, dp4a, fp32 matvec, GEMM-N, dequant) get bit-identical SoA variants and auto-route per repacked handle.
- … CudaMmqSoaE2ETests: prefill + 4 decode, maxAbs 0): the bit-exact oracles exercise the SoA path (converting every reader is what keeps them meaningful; a half-migration would false-pass on garbage).
- … SHARPI_MMQ_SOA=0 reverts.

Dead-ends ruled out (documented, reverted)
- … .ca beats .cg but both lose).

Diagnostics kept
- CudaMmqRooflineProbe (int8 TOPS vs peak at FFN shapes), but probes must use the real prefill nTok.

Test plan
- dotnet test tests/SharpInference.Tests.ForwardPass --filter "Gemma4Cuda|CudaFlashAttn|CudaMmq" (26 green).

🤖 Generated with Claude Code