perf(cuda): dense Q4_K batched prefill + decode SoA — Qwen3-8B 61.2→432 prefill, 65→74.7 decode (#156/#158/#160)#161
Open the all-GPU batched-trunk prefill (originally Gemma-4-only, #136) to any dense model the batched kernels cover, e.g. Qwen3-8B Q4_K. Three Gemma assumptions in the batched layer body are generalized:
- LayerHeadDim null-coalesce to _headDim (non-Gemma has no per-layer head_dim)
- FFN activation dispatch on hp.FfnActivation (SiLuMul for SwiGLU vs GeluTanhMul)
- cached _attnScale (1f for Gemma, -1f else so the kernel derives 1/sqrt(head_dim))

The gate (IsGemma4BatchedPrefillSupported -> IsBatchedPrefillSupported) drops the _isGemma4Like requirement but keeps every real guard (MoE, TQ, attn-bias, non-NEOX RoPE, L2 QK-norm, batchable dtype). Methods renamed to drop the Gemma4 suffix.

Fix a latent ordering bug exposed by the generalization: the batched body applied QK-norm before RoPE, but the dense per-token Forward applies RoPE before QK-norm. RoPE does not commute with per-channel-weighted RMSNorm, so this diverged ~9 logits on Qwen3. RoPE/QK-norm are now ordered to match the corresponding per-token oracle (Gemma norm->rope; dense rope->norm) and the batched RoPE honors NoRopeLayerStep.

New Qwen3CudaBatchedPrefillTests: default path argmax-stable, flash-off bit-exact vs the per-token loop. Gemma4CudaBatchedPrefillTests still 5/5 (no regression).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
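The ordering claim in the commit above (RoPE does not commute with a per-channel-weighted RMSNorm) can be checked numerically. The following is a minimal sketch in plain Python, not the project's kernels: `rms_norm`, `rope_neox`, the input vector, and the weight values are all illustrative assumptions. It relies on the fact that a rotation preserves the vector's RMS, so the two orders agree only when every channel shares the same weight.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm with a per-channel weight: x / rms(x) * weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

def rope_neox(x, pos, theta=10000.0):
    """NEOX-style RoPE: channel i is rotated together with channel i + d/2."""
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        angle = pos * theta ** (-2.0 * i / len(x))
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i] * s + x[i + half] * c
    return out

def order_gap(x, weight, pos=3):
    """Largest absolute difference between norm->rope and rope->norm."""
    norm_then_rope = rope_neox(rms_norm(x, weight), pos)
    rope_then_norm = rms_norm(rope_neox(x, pos), weight)
    return max(abs(a - b) for a, b in zip(norm_then_rope, rope_then_norm))

x = [0.5, -1.25, 2.0, 0.75, -0.5, 1.5, -2.0, 0.25]
uniform = [1.0] * 8                                   # same weight on every channel
learned = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]    # made-up per-channel weights

gap_uniform = order_gap(x, uniform)   # orders agree up to rounding
gap_learned = order_gap(x, learned)   # orders visibly disagree
```

With uniform weights the gap is at rounding level; with distinct weights on the two channels of a rotated pair it is not, which is why a batched path has to use the same order as the per-token path it is compared against.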
…batched prefill Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Generalize the Gemma 4 decode CUDA-graph capture/replay (#136) to the non-Gemma per-token CudaForwardPass.Forward, so dense Q4_K models (Qwen3-8B etc.) get the same host-launch-overhead reduction. Decode is HBM-bandwidth-bound, so the win is killing the ~1k launches/token, not kernel throughput.
- Refactor Forward into token-varying prefix (embed) -> RunDeviceRegion (layer loop + final norm + output projection) -> TQ ring-advance + logits download. The TQ ring-advance (host state) and SnapKV Q-capture (host-varying device offset) stay outside the captured region.
- Wrap RunDeviceRegion with the generic capture helpers via the new TryRunDeviceRegionViaGraph, mirroring TryRunGemma4DeviceRegionViaGraph. Bail on _tqEnabled / _kvEvictedCount>0 / _snapKvCaptureSlot>=0 / _isMoE (each breaks static topology or does an illegal mid-capture sync).
- Pre-grow the Q4_K/Q8_0 dp4a Q8_1 input scratch to max(embDim,intermDim) before first capture (capture forbids cudaMalloc).
- New Qwen3CudaGraphParityTests: graph replay is bit-identical to direct launch over 8 decode steps, plus the SnapKV-configured-no-evict guard. Gemma4 graph tests + Qwen3 batched-prefill tests still green.

Qwen3-8B Q4_K -g -1 decode: 65 -> 70 t/s (+7%, same-session A/B; SHARPI_CUDA_GRAPH=0 to bisect).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Route Q4_K trunk matmuls in the batched prefill through the compute-bound dequant->fp16->cuBLAS GEMM (weight read once per batch) instead of the memory-bound matvec GEMM-N (weight re-streamed once per token). Profiling Qwen3-8B Q4_K prefill confirmed the Item-C gate: the trunk matmuls are 99.5% of layer time (attn ~0.5% after flash), and the matvec path is weight-bandwidth bound (~4.5GB x N_tokens of HBM traffic).
- New llm_dequant_q4k_to_f16 kernel: decodes the Q4_K super-block (d*sc*nibble - dmin*mn, identical to llm_embed_lookup_q4k / llm_matvec_q4k) to fp16, one 256-thread block per weight row.
- Generalize MatMulBatchedGemm from Q8_0-only to Q8_0|Q4_K: dtype-select the dequant kernel + dtype-aware cols alignment (Q4_K super-block = 256).
- GpuMatMulBatched dispatches Q4_K to the GEMM path (cols % 256), with a defensive matvec fallback otherwise.
- New parity oracle Qwen3_8B_BatchedPrefill_Q4KGemm_ArgmaxStable: GEMM vs bit-exact matvec, argmax-equal + top-5 overlap (fp16, not byte-exact, same contract as the Q8_0 #141 GEMM). FlashOff_MatchesSequential pinned to PrefillGemmEnabled=false to keep its bit-exact-matvec purpose.
- Prefill-profile gains a matmul= breakdown alongside attn=.

Qwen3-8B Q4_K -g -1 prefill: 119.8 -> 432 t/s (3.6x, same-session A/B; SHARPI_PREFILL_GEMM=0 to bisect); --no-thinking 61.8 -> 427. Decode and the --tq rows (batched prefill disabled for TQ) unchanged. llama.cpp b8585 pp1008 = 5764 t/s; the remaining ~13x is its int8 Q4_K MMQ kernel (Item C2 follow-up). README Qwen3-8B rows updated.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
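For readers unfamiliar with the super-block decode named in the commit above (d*sc*nibble - dmin*mn), here is a CPU reference sketch in Python. It assumes the ggml `block_q4_K` byte layout as commonly documented (fp16 `d`, fp16 `dmin`, 12 packed scale bytes, 128 nibble bytes, 144 bytes total); the project's `llm_dequant_q4k_to_f16` kernel is not shown on this page and may be organised differently. The function names and the hand-built test block are illustrative only.

```python
import struct

BLOCK_BYTES = 144   # 2 (d) + 2 (dmin) + 12 (packed scales) + 128 (nibbles)

def get_scale_min_k4(j, scales):
    """Unpack the 6-bit (scale, min) pair for sub-block j (0..7)."""
    if j < 4:
        return scales[j] & 63, scales[j + 4] & 63
    sc = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4)
    mn = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
    return sc, mn

def dequant_q4k_block(block):
    """Decode one 144-byte super-block to 256 floats: d*sc*nibble - dmin*mn."""
    assert len(block) == BLOCK_BYTES
    d, dmin = struct.unpack_from("<ee", block, 0)   # two little-endian fp16 values
    scales = block[4:16]
    qs = block[16:144]
    out = []
    for chunk in range(4):                          # 64 weights per chunk
        q = qs[32 * chunk: 32 * chunk + 32]
        sc_lo, mn_lo = get_scale_min_k4(2 * chunk, scales)
        sc_hi, mn_hi = get_scale_min_k4(2 * chunk + 1, scales)
        out += [d * sc_lo * (b & 0x0F) - dmin * mn_lo for b in q]   # low nibbles
        out += [d * sc_hi * (b >> 4) - dmin * mn_hi for b in q]     # high nibbles
    return out

# Hand-built block: d = 1.0, dmin = 0.5, every byte holds nibbles (low=1, high=2).
scales = bytearray(12)
scales[0], scales[4] = 3, 2          # sub-block 0: scale 3, min 2
scales[1], scales[5] = 0xC5, 0x42    # sub-block 1: scale 5, min 2, plus high bits for sub-block 5
scales[9] = 0x7A                     # sub-block 5: low 4 bits of scale (0xA) and min (0x7)
block = struct.pack("<ee", 1.0, 0.5) + bytes(scales) + bytes([0x21] * 128)
weights = dequant_q4k_block(block)
```

The branch in `get_scale_min_k4` is the "unpack switch" that later commits on this page remove from the hot path by pre-unpacking the scales.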
Adds llm_mmq_q4k, the maximal Item C: each Q4_K weight is read once as int8 (nibble-expanded, get_scale_min_k4 decode) and fed to the m16n8k32 s8 mma with no fp16 dequant temp to HBM, the cost that capped C1's dequant→fp16→cuBLAS GEMM. The kernel mirrors llm_mmq_q8_0's tiling/fragment map byte-for-byte (validated by #141); the only Q4_K-specific work is the 4-bit→int8 weight expansion, the per-(row,sub-block) (scale,min) unpack, and the asymmetric min-bias term −super_dmin·mn·(d_a·Σq_a). That activation sum d_a·Σq_a is now packed as the fp16 `s` half of each q8_1 block by llm_quantize_q8_1 (every other reader masks the d-word with 0xffff, so the high half was inert; this mirrors ggml block_q8_1's ds).

Wired as the default Q4_K prefill path under PrefillMmqEnabled (cols%256, weight read once as int8); SHARPI_PREFILL_MMQ=0 reverts to the C1 GEMM, =…_GEMM=0 to the bit-exact matvec. Argmax-stable, not bit-exact (both operands int8-quantized + the min-bias rounds through fp16 s).

Qwen3-8B same-session A/B vs C1: +25% at ~100-tok prompts (284→355 t/s, where C1 still pays its fp16-temp write), converging to a tie by ~1K ctx (430→432 @1008) as cuBLAS amortizes that temp, so the 1K README column is unchanged but short-context prefill and prefill VRAM both improve. The remaining gap to llama.cpp's MMQ is its cp.async pipelining, which hides the weight re-read across token tiles that cuBLAS amortizes via L2.

Tests: CudaMmqQ4KTests (GPU MMQ vs CPU DotQ4K fp32 ref, 0 mismatches; isolates the nibble/scale decode + min-bias + fragment map) and Qwen3_8B_BatchedPrefill_Q4KMmq_ArgmaxStable (MMQ vs bit-exact matvec GEMM-N, argmax + top-5). 220/220 ForwardPass tests green (the shared quantize_q8_1 change is parity-safe across all q8_1 paths).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
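The min-bias term in the commit above follows from expanding the dot product of dequantized weights (d·sc·q_w − dmin·mn) with dequantized activations (d_a·q_a): the integer part Σ q_w·q_a goes to the int8 mma, and the remainder depends only on the activation sum d_a·Σq_a. The sketch below checks that identity for one 32-element sub-block in double precision. All names and constants are illustrative assumptions, and the fp16 rounding of `s` mentioned in the commit is deliberately not modelled.

```python
import random

def dot_dequantized(d, dmin, sc, mn, q_w, d_a, q_a):
    """Reference: dequantize both operands to float, then take the dot product."""
    return sum((d * sc * w - dmin * mn) * (d_a * a) for w, a in zip(q_w, q_a))

def dot_integer_plus_min_bias(d, dmin, sc, mn, q_w, d_a, q_a):
    """Integer dot product plus the asymmetric min-bias correction."""
    int_dot = sum(w * a for w, a in zip(q_w, q_a))   # what an int8 mma would accumulate
    s = d_a * sum(q_a)                               # activation sum, computable once per block
    return d * sc * d_a * int_dot - dmin * mn * s

rng = random.Random(0)
q_w = [rng.randrange(16) for _ in range(32)]          # 4-bit weight codes
q_a = [rng.randrange(-127, 128) for _ in range(32)]   # int8 activation codes
args = (0.0123, 0.0045, 37, 11, q_w, 0.02, q_a)       # d, dmin, sc, mn, q_w, d_a, q_a

ref = dot_dequantized(*args)
fast = dot_integer_plus_min_bias(*args)
```

Because `s` depends only on the activations, it can be computed once when the activations are quantized and reused for every weight row, which is the motivation the commit gives for storing it alongside each q8_1 block.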
Qwen3-8B Q4_K CUDA decode matvec hit only ~74% of HBM peak vs the Q8_0 dp4a path's ~90%. The gap is NOT the #149 funnelshift: the 144-B Q4_K super-block is already 16-byte aligned. It is per-super-block COMPUTE: the get_scale_min_k4 6-bit (scale,min) unpack switch forms a dependent chain that starves memory-level parallelism so loads never saturate HBM.

Fix: a one-time scale-pre-unpacked SoA repack at upload (RepackQ4KSoa / llm_q4k_repack_soa) splits each block into [quants][unpacked scale/min bytes][d|dmin] regions. The decode matvec (llm_matvec_q4k_soa) and prefill MMQ (llm_mmq_q4k_soa) then read plain bytes, with no switch. The stored scale/min integers are identical, so both kernels are bit-identical to the interleaved versions (8-warp reduction kept to preserve FP order).

Same-session A/B (RTX 4070 Ti, opt-in SHARPI_Q4K_SOA=1): matvec BW 74% -> 89% (+13-15%); decode 70.0 -> 74.7 t/s (+7%); prefill +5% (MMQ benefits too); bit-identical, coherent output.

Dense-only: repack gated on !_isMoE; the unconverted Q4_K readers (GEMM-N, dequant->GEMM, MTP N2) throw if a SoA handle reaches them. Default-on is blocked on converting the fragile MTP N2 byte-parity path.

Tests: CudaQ4KSoaTests (decode + MMQ bit-identical + A/B), CudaDecodeMatvecQ4KRooflineProbe. 26 targeted CUDA correctness tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
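The repack idea in the commit above can be sketched on the CPU as follows. This is a hypothetical layout, not the project's `RepackQ4KSoa`: it assumes ggml-style 144-byte interleaved super-blocks as input and emits three contiguous regions, with the eight (scale, min) pairs stored as 16 plain bytes per block in the order sc0, mn0, sc1, mn1, and so on. The real region order, padding, and alignment are not given on this page.

```python
import random

def get_scale_min_k4(j, scales):
    """Unpack the 6-bit (scale, min) pair for sub-block j (0..7)."""
    if j < 4:
        return scales[j] & 63, scales[j + 4] & 63
    sc = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4)
    mn = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
    return sc, mn

def repack_q4k_soa(blocks):
    """Split interleaved 144-byte super-blocks into three contiguous regions."""
    assert len(blocks) % 144 == 0
    quants, scale_min, d_dmin = bytearray(), bytearray(), bytearray()
    for off in range(0, len(blocks), 144):
        blk = blocks[off:off + 144]
        d_dmin += blk[0:4]                    # fp16 d and dmin, copied untouched
        for j in range(8):                    # pay the unpack branch once, at upload
            sc, mn = get_scale_min_k4(j, blk[4:16])
            scale_min += bytes((sc, mn))
        quants += blk[16:144]                 # 128 nibble bytes, copied untouched
    return bytes(quants), bytes(scale_min), bytes(d_dmin)

# Three random super-blocks stand in for a weight row.
rng = random.Random(1)
raw = bytes(rng.randrange(256) for _ in range(3 * 144))
quants, scale_min, d_dmin = repack_q4k_soa(raw)
```

A reader of this layout fetches `scale_min[16 * b + 2 * j]` and the byte after it instead of branching on `j`. Since the stored integers are the same values the interleaved decode would produce, the arithmetic downstream is unchanged, which is consistent with the bit-identity the commit reports.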
…aders (#160)
Flip SHARPI_Q4K_SOA to default-on (set =0 to revert) so every dense Q4_K CUDA model gets the scale-pre-unpacked decode win (Qwen3-8B 70.0 → 74.7 t/s, +7% decode / +5% prefill) without an opt-in flag. The blocker was the three Q4_K readers that threw on a SoA weight handle. All three now have bit-identical SoA twins:
- llm_matvec_q4k_n2_soa: the N=2 MTP batched-verify reader (dense Qwen3.6-27B-MTP). Same two-input dp4a + FP accumulation order, NWARPS=8, only scale/min come from the pre-unpacked SoA bytes. The fragile MTP llama.cpp byte-parity oracle (MtpDecoder_GreedyParity_LlamaCpp) still holds: cumulative trunk drift over 64 layers is unchanged because the arithmetic order is byte-identical.
- llm_matvec_q4k_gemm_n_soa + llm_dequant_q4k_to_f16_soa: the SHARPI_PREFILL_MMQ=0 fallback prefill readers, so the flag is safe with MMQ off.

Backend auto-routes per repacked handle (DispatchMatVecQ4KN2/Batched gain a soa param); the MoE repack stays gated !_isMoE at upload.

Validation: 7 CudaQ4KSoaTests bit-identical (decode/MMQ/N2/GEMM-N/dequant) + MTP byte-parity + 13 HybridGdn MTP + 6 Qwen3 CUDA parity + SnapKV/MMQ-E2E green. Same-session A/B: default (env unset) 74.7 t/s vs SHARPI_Q4K_SOA=0 69.7 t/s.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Code Review
This pull request introduces a scale-pre-unpacked Structure of Arrays (SoA) layout repack and optimized kernels for Q4_K weights on CUDA, improving decode and prefill performance for dense models like Qwen3-8B. It also generalizes the all-GPU batched-trunk prefill to non-Gemma dense models, adds CUDA graph capture/replay for non-Gemma dense decode regions, and includes comprehensive benchmark scripts and parity tests. The reviewer noted that several of the newly added Q4_K SoA and dequantization kernels are missing from ForceEagerJit(), which should be added to prevent JIT compilation stutters during the first inference run.
_matvecQ80GemmNKernel, _mmqQ80Kernel, _mmqQ80SoaKernel, _mmqQ4kKernel,
_matvecQ80Dp4aSoaKernel, _q80RepackSoaKernel,
_matvecQ80SoaKernel, _matvecQ80GemmNSoaKernel, _dequantQ80F16SoaKernel,
The newly introduced Q4_K SoA and dequantization kernels (_mmqQ4kSoaKernel, _q4kRepackSoaKernel, _matvecQ4KSoaKernel, _matvecQ4KN2SoaKernel, _matvecQ4KGemmNSoaKernel, _dequantQ4KF16Kernel, and _dequantQ4KF16SoaKernel) are missing from ForceEagerJit(). Adding them here ensures they are eagerly compiled during initialization, preventing JIT compilation stutters during the first inference run.
_matvecQ80GemmNKernel, _mmqQ80Kernel, _mmqQ80SoaKernel, _mmqQ4kKernel, _mmqQ4kSoaKernel,
_matvecQ80Dp4aSoaKernel, _q80RepackSoaKernel, _q4kRepackSoaKernel,
_matvecQ80SoaKernel, _matvecQ80GemmNSoaKernel, _dequantQ80F16SoaKernel,
_matvecQ4KSoaKernel, _matvecQ4KN2SoaKernel, _matvecQ4KGemmNSoaKernel,
_dequantQ4KF16Kernel, _dequantQ4KF16SoaKernel,
Remaining Qwen3-8B Q4_K gap to llama.cpp (decode non-matvec cost + prefill cp.async MMQ) tracked as follow-up #162.
ForceEagerJit() was missing the Q4_K SoA repack + decode/N2/GEMM-N/dequant readers (and the pre-existing AoS dequant), so their SASS was finalized lazily on first decode instead of at load. Add them so first-token latency pays no per-kernel JIT stutter. No behavior change. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
_diff_*.txt were scratch files a review agent wrote during PR #161 review and got swept in by `git add -A`. Not project files; remove. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Addressed the gemini-code-assist suggestion in |
Generalizes the GDN-hybrid batched-prefill machinery to dense Q4_K models and closes most of the CUDA decode/prefill gap on Qwen3-8B. Builds across #156 (batched prefill → compute-bound GEMM → int8 MMQ → decode SoA), #158 (decode CUDA graphs), and #160 (SoA default-on).
Headline (Qwen3-8B Q4_K_M, RTX 4070 Ti, ~1K ctx, same-session A/B):
What's in here
#156: dense Q4_K prefill
- … MatMulBatched GEMM-N path) to dense Q4_K (1b4bb4c)
- … 50bac9d)
- … llm_mmq_q4k, mma.m16n8k32.s8): weight read once as int8, no fp16 HBM temp; SHARPI_PREFILL_MMQ=0 reverts to C1 (cb72245)
- … get_scale_min_k4 unpack switch: matvec 74% → 89% of HBM peak (0d5fe20)

#158: decode CUDA graphs for the non-Gemma dense Forward (capture/replay the per-token device region), 65 → 70 t/s; SHARPI_CUDA_GRAPH=0 to bisect (e1de163)

#160: SoA default-on (621184f): convert the Q4_K readers that previously threw on a SoA handle to bit-identical SoA twins, then flip SHARPI_Q4K_SOA to default-on (=0 reverts). MoE Q4_K stays interleaved (repack gated !_isMoE).
- … CudaForwardPass, the dense non-GDN pass, repacks): llm_matvec_q4k_gemm_n_soa + llm_dequant_q4k_to_f16_soa, the SHARPI_PREFILL_MMQ=0 fallback prefill readers. Both bit-identical (tested). These are what default-on actually required (decode matvec + MMQ prefill SoA shipped in 0d5fe20).
- llm_matvec_q4k_n2_soa (MTP batched-verify). The dense-MTP CUDA pass is CudaHybridGdnForwardPass, which does not repack to SoA today, so no production path sends a SoA handle to MatMulN2: the old throw was unreachable and this reader is not yet on a live path. It is bit-identity-tested in isolation (Q4KSoaN2_BitIdenticalToInterleaved, maxAbs==0) so a future wiring of SoA into the GDN pass works out of the box. (Review correction: an earlier description called N2 "the blocker"; it was not.)
- … f8a262c): eager-JIT the new + pre-existing Q4_K SoA kernels so first-decode pays no JIT stutter (gemini-code-assist).

Correctness
- … SHARPI_PREFILL_GEMM=0 reverts to bit-exact); the SoA repack is bit-identical.
- CudaQ4KSoaTests maxAbs==0 bit-identity oracles (decode/MMQ/N2/GEMM-N/dequant). The SoA conversion preserves FP accumulation order exactly (same NWARPS=8), which is why bit-identity holds.
- MtpDecoder_GreedyParity_LlamaCpp stays green but runs on CpuBackend: it guards the CPU MTP path and is unaffected by the CUDA SHARPI_Q4K_SOA flag (it does not exercise the CUDA N2 SoA reader).
- CudaMmqQ4KTests, Qwen3CudaBatchedPrefillTests, Qwen3CudaGraphParityTests, CudaDecodeMatvecQ4KRooflineProbe, 13 HybridGdn MTP, SnapKV/MMQ-E2E: all green.
- … return silently without a GPU + the local .gguf models, so the green check does not cover them; the maxAbs==0 bit-identity and same-session A/B (default 74.7 vs =0 69.7 t/s) were verified on a local RTX 4070 Ti.

Remaining gap → #162
llama.cpp b8585 is pp1008 ~5764 / tg128 ~78–91. Prefill gap is its cp.async-pipelined MMQ; decode is near the HBM ceiling (~89%), so the next lever is the non-matvec decode cost (attention/RoPE/norms/sampler), not the matvec. Test-hardening follow-ups (default-on A/B test, MoE-skip negative test) also tracked in #162.
🤖 Generated with Claude Code
Closes #156, closes #158, closes #160.