
Gemma 4 CUDA perf: collapse per-token kernel launches + enable batched prefill #136

@pekkah


✅ RESOLVED (2026-06-14)

The headline launch-collapse + batched-prefill work shipped: PLE projection fused (llm_ple_proj_norm), batched-trunk prefill enabled for Gemma 4 (#137), and CUDA-graph decode capture/replay made real and default-on (#139/#142; SHARPI_CUDA_GRAPH=0 reverts), plus the attnScale and dual-HeadNorm fusions. Prefill/decode then moved to cuBLAS GEMM / dp4a (#141/#142). Decode is now HBM-bandwidth-bound (matvecs at ~90% HBM utilization), so the two remaining P1 micro-fusions (layer_output_scale, EmbeddingScale bake) are moot: there is no launch-count lever left. The remaining prefill compute lever lives in #152. Closing.


Performance review of the Gemma 4 CUDA execution path (CudaForwardPass.ForwardGemma4, CudaHybridForwardPass.GpuLayerGemma4/Prefill).

Key finding: Gemma 4 decode is launch-bound, not compute-bound. The inner loop issues ~28 kernel launches/layer (~950/token over ~34 layers) plus ~200 upfront for PLE projection prep — ~1,100–1,200 kernel launches per token. At ~4–6 µs host launch cost that's ~5–7 ms/token of pure host overhead (effective ceiling ~150–200 t/s before any compute). PLE alone is ~half the launches. Numerics (SWA bounding, dual-RoPE, KV-share, per-layer head_dim) are already correct — the wins are in launch count and prefill batching.
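As a sanity check on the estimate above, the launch budget can be multiplied out in a few lines of Python. The constants are the issue's own rough figures; the function names are illustrative only, and "fused PLE" here models just the upfront ~200-launch prep collapsing to one launch (per-layer work is unchanged).

```python
# Launch budget for one Gemma 4 decode token, using the rough figures above.
LAUNCHES_PER_LAYER = 28
NUM_LAYERS = 34
PLE_PREP_LAUNCHES = 200


def launches_per_token(per_layer=LAUNCHES_PER_LAYER, layers=NUM_LAYERS,
                       ple_prep=PLE_PREP_LAUNCHES):
    """Total kernel launches issued for one decoded token."""
    return per_layer * layers + ple_prep


def host_overhead_ms(launches, us_per_launch):
    """Pure host-side launch cost per token, in milliseconds."""
    return launches * us_per_launch / 1000.0


def ceiling_tokens_per_s(overhead_ms):
    """Throughput upper bound if launch overhead were the only cost."""
    return 1000.0 / overhead_ms


if __name__ == "__main__":
    before = launches_per_token()               # 28 * 34 + 200 = 1152
    after_ple = launches_per_token(ple_prep=1)  # upfront PLE prep fused into one launch
    for n, label in ((before, "today"), (after_ple, "fused PLE prep")):
        for us in (4.0, 6.0):
            ms = host_overhead_ms(n, us)
            print(f"{label}: {n} launches @ {us:.0f} us -> "
                  f"{ms:.1f} ms/token, ceiling {ceiling_tokens_per_s(ms):.0f} t/s")
```

At 4 to 6 µs per launch, 1,152 launches come to roughly 4.6 to 6.9 ms/token, i.e. a ceiling of roughly 145 to 217 t/s, which matches the ranges quoted above.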

The recent GDN/MoE/MTP perf work is architecturally irrelevant to Gemma 4 (dense, no SSM, no experts), but two pieces transfer: the non-owning View slice type (#111) and the batched-trunk GEMM-N machinery (#123, currently gated off for Gemma 4).

Todo

P0 — Collapse PLE launches (biggest decode win)

  • Replace PLE staging CopyDeviceRegion round-trips with CudaBackend.View(parent, elemOffset, elemCount) (added in #111, which batched the GDN-hybrid prefill trunk). The "Tensor can't encode a device-pointer offset" comments at CudaForwardPass.cs:1122-1126 and 1156-1158 are stale. No numeric change.
  • Replace the L-iteration per-slice loop in BuildPerLayerProjectionsGpu (CudaForwardPass.cs:1128-1140) with a single fused batched kernel (llm_ple_proj_norm: per-pleWidth-slice RmsNorm + add PLE-row slice + scale 1/√2 over all L rows). ~200 launches/token → 1.
  • Hybrid path: replace the per-layer H2D UploadInto of the PLE slice (CudaHybridForwardPass.cs:1803-1804) with a View into a once-uploaded _projPerLayer.
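For oracle-testing the fused kernel, a plain-Python reference of what the llm_ple_proj_norm bullet describes may help. It is a sketch under assumptions: the projection and PLE row are flat row-major [L, pleWidth] buffers, one RmsNorm weight vector of length pleWidth is shared by all layers, the norm has the form x / rms(x) * w with eps = 1e-6, and the function name and signature are hypothetical. Each layer's slice starts at layer * pleWidth, the same offset a View(parent, elemOffset, elemCount) would describe, so no staging copy is involved.

```python
import math


def ple_proj_norm(proj, ple_row, norm_w, num_layers, ple_width, eps=1e-6):
    """One-token reference for the fused PLE projection step (hypothetical oracle).

    proj, ple_row: flat lists of num_layers * ple_width floats, row-major [L, pleWidth].
    norm_w: ple_width RmsNorm weights, assumed shared across layers.
    Returns (RmsNorm(proj slice) * w + ple slice) / sqrt(2) for every layer slice.
    """
    assert len(proj) == len(ple_row) == num_layers * ple_width
    assert len(norm_w) == ple_width
    inv_sqrt2 = 1.0 / math.sqrt(2.0)
    out = [0.0] * len(proj)
    for layer in range(num_layers):
        off = layer * ple_width  # offset a View(parent, off, ple_width) would address
        sl = proj[off:off + ple_width]
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in sl) / ple_width + eps)
        for j in range(ple_width):
            normed = sl[j] * inv_rms * norm_w[j]
            out[off + j] = (normed + ple_row[off + j]) * inv_sqrt2
    return out
```

On the GPU, one natural mapping is one thread block per layer slice (grid = L), which is what turns the per-slice loop into a single launch; a batched-prefill variant would add a token dimension to the grid.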

P0 — Enable batched-trunk prefill for Gemma 4 (biggest TTFT win)

  • Lift the !_isGemma4Like gate in IsBatchedPrefillSupported (CudaHybridForwardPass.cs:975-982). Blockers to resolve:
      ◦ Thread per-layer LayerHeadDim[i] through the GEMM-N trunk shapes (decode path already proves per-layer View slicing works).
      ◦ Add batched RoPE variants for dual-RoPE: RoPEWithFactors (global) and plain-θ (SWA). Current batched RoPE kernel is NEOX-only.
      ◦ Batched PLE over N tokens (reuses the P0 fused projection kernel with a batch dim).
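The dual-RoPE blocker can be pinned down with a CPU reference that a batched kernel could be checked against. This sketch assumes NEOX pairing (element i rotated with element i + head_dim/2) for both layer types and models the two variants as differing only in the base theta passed in and in optional per-pair frequency divisors for the "with factors" case (following the ggml convention of dividing the frequency by the factor). Rows are assumed row-major [head, dim]; the function name and parameters are hypothetical, and no specific Gemma 4 theta values are implied.

```python
import math


def rope_neox_batched(x, positions, n_heads, head_dim, theta, freq_factors=None):
    """Apply NEOX-style RoPE in place to a batch of N tokens (reference sketch).

    x: N rows, each n_heads * head_dim floats, row-major [head, dim].
    positions: N absolute token positions, one per row.
    theta: base frequency for this layer type (e.g. global vs SWA layers).
    freq_factors: optional head_dim // 2 divisors; None means plain theta.
    """
    half = head_dim // 2
    for row, pos in zip(x, positions):
        for h in range(n_heads):
            base = h * head_dim
            for i in range(half):
                freq = theta ** (-2.0 * i / head_dim)
                if freq_factors is not None:
                    freq /= freq_factors[i]
                angle = pos * freq
                c, s = math.cos(angle), math.sin(angle)
                a, b = row[base + i], row[base + i + half]  # NEOX pairs i with i + half
                row[base + i] = a * c - b * s
                row[base + i + half] = a * s + b * c
    return x
```

Because head_dim is a per-call argument, a batched trunk would pass LayerHeadDim[i] for layer i, which is the same per-layer threading the first blocker asks for.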

P1 — Fold scalar fix-up kernels into neighbors

  • Add explicit attnScale param to Attention/AttentionSwa (pass 1.0f for Gemma 4) and delete the Q prescale ScaleInPlace(qView, √layerHd) (CudaForwardPass.cs:1018, CudaHybridForwardPass.cs:1746). ~34 launches/token gone.
  • Fuse layer_output_scale (CudaForwardPass.cs:1064-1065) into the preceding PLE/post-FFN AddInPlace as a fused scale-add.
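  • Bake EmbeddingScale (√embDim, :936-937) into the embedding dequant at load. (Folding it into the first RmsNorm weight would not be equivalent: RmsNorm is invariant to input scale, and the residual stream still needs the scaled embedding.)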
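The attnScale change is numerically a no-op, which a tiny single-query attention makes easy to check: prescaling Q by √hd and then applying the usual 1/√hd score scale yields the same scores as passing attnScale = 1.0 on the raw Q (up to rounding). The second helper shows the shape of a fused scale-add, assuming layer_output_scale multiplies the sum produced by the preceding AddInPlace. Both functions are hypothetical sketches, not the engine's API.

```python
import math


def attention_one_query(q, keys, values, attn_scale):
    """Single-query softmax attention with an explicit score scale (sketch)."""
    scores = [attn_scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wj * v[d] for wj, v in zip(w, values)) / z for d in range(len(values[0]))]


def fused_scale_add(x, y, scale):
    """x <- (x + y) * scale in one pass (hypothetical AddInPlace + layer_output_scale)."""
    for i in range(len(x)):
        x[i] = (x[i] + y[i]) * scale
    return x


if __name__ == "__main__":
    q = [0.3, -1.2, 0.5, 2.0]
    keys = [[1.0, 0.0, 0.5, -0.5], [0.2, 0.1, -1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
    values = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
    hd = len(q)
    prescaled = [v * math.sqrt(hd) for v in q]                        # today: ScaleInPlace(q, sqrt(hd))
    today = attention_one_query(prescaled, keys, values, 1.0 / math.sqrt(hd))
    proposed = attention_one_query(q, keys, values, 1.0)              # attnScale = 1.0, no prescale
    print(today, proposed)
```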

P1 — Make the hybrid command-buffer real (CUDA Graphs)

  • Implement BeginRecord/RecordBarrier/EndRecordAndSubmit (currently no-ops at CudaBackend.cs:685-688) as real CUDA Graph capture + replay with updatable kernel-node params for the changing position/token. Structural fix for the ~1,000-launch/token problem.
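The record/replay idea behind this item can be illustrated with a host-only toy model (plain Python, not CUDA). In the CUDA runtime the corresponding steps are stream capture (cudaStreamBeginCapture / cudaStreamEndCapture), cudaGraphInstantiate once, and cudaGraphLaunch per token; the changing position/token can be supplied either by updating kernel-node params (cudaGraphExecKernelNodeSetParams) or, as modeled here, by having kernels read it from a device-resident buffer that the host updates before each replay. The class, method, and kernel names are hypothetical and only loosely mirror BeginRecord/EndRecordAndSubmit.

```python
class ToyGraph:
    """Record-once / replay-many command list: a toy model of CUDA Graph capture."""

    def __init__(self):
        self.nodes = []
        self.recording = False
        self.submissions = 0  # host-side submissions: the cost a graph amortizes

    def begin_record(self):
        self.nodes = []
        self.recording = True

    def launch(self, kernel, *args):
        if self.recording:
            self.nodes.append((kernel, args))  # captured, not executed (like stream capture)
        else:
            self.submissions += 1  # eager path: one submission per kernel
            kernel(*args)

    def end_record(self):
        self.recording = False

    def replay(self):
        self.submissions += 1  # the whole captured sequence goes out as one submission
        for kernel, args in self.nodes:
            kernel(*args)


pos_slot = [0]  # stands in for a device buffer holding the current position
trace = []


def rope_kernel(layer):
    # Reads the position at execution time, so the graph never needs re-recording.
    trace.append(("rope", layer, pos_slot[0]))


if __name__ == "__main__":
    g = ToyGraph()
    g.begin_record()
    for layer in range(2):
        g.launch(rope_kernel, layer)
    g.end_record()
    for pos in (5, 6):
        pos_slot[0] = pos
        g.replay()
    print(trace, g.submissions)
```

The device-buffer approach keeps the captured graph fully static; the node-param approach avoids the extra indirection but requires a param update per changed node before each launch.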

P2 — Minor fusions

  • Fuse dual Q+K HeadNorm (CudaForwardPass.cs:974-978) into a single two-input kernel (MatVecDual-style). ~34 launches/token.
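A CPU reference for the fused Q+K HeadNorm shows the intended single-launch structure: one grid covering n_q_heads + n_k_heads heads, where each "block" picks its input tensor and weight vector by index, in the spirit of MatVecDual. The function name, the x / rms(x) * w form, eps, and the [head, dim] layout are assumptions for illustration.

```python
import math


def head_norm_dual(q, k, q_w, k_w, n_q_heads, n_k_heads, head_dim, eps=1e-6):
    """Per-head RmsNorm of Q and K in one pass over n_q_heads + n_k_heads heads.

    Mirrors a single two-input launch: "block" b < n_q_heads normalizes Q head b,
    the remaining blocks normalize K heads (hypothetical layout, in place).
    """
    for b in range(n_q_heads + n_k_heads):
        if b < n_q_heads:
            buf, w, h = q, q_w, b
        else:
            buf, w, h = k, k_w, b - n_q_heads
        off = h * head_dim
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in buf[off:off + head_dim]) / head_dim + eps)
        for d in range(head_dim):
            buf[off + d] = buf[off + d] * inv_rms * w[d]
    return q, k
```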

🤖 Generated with Claude Code


Update (batched prefill landed): all-GPU batched-trunk prefill for Gemma 4 shipped. Bonus primitives added along the way: llm_matvec_q8_0_gemm_n (the on-disk E4B is Q8_0, which was previously unbatchable) and llm_gelu_tanh_mul_strided (fully batched PLE without a transpose). All new batched kernels have bit-exact oracle tests, and the end-to-end test Gemma4_E4B_BatchedPrefill_MatchesSequential confirms parity with the per-token loop. Follow-ups:

  • (a) Contexts >4096 use the per-token fallback: global layers need AttentionBatchedWave (SWA layers are window-bounded, so they are fine).
  • (b) SnapKV-active prefill falls back (the batched trunk does not capture Q).
  • (c) Q8_0 GEMM-N re-reads weights per token; weight-reuse tiling is a perf follow-up.
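Follow-up (c) is about loop order. The sketch below models a Q8_0 weight row as blocks of 32 int8 values sharing one scale (the ggml Q8_0 layout; the fp16 scale is held as a Python float here) and contrasts a per-token walk, which re-reads every weight block once per token, with a weight-reuse order that dequantizes each block once and applies it to all N tokens. Function names and the nested-list layout are hypothetical; on the GPU the reuse would come from keeping a dequantized tile in registers or shared memory while iterating over tokens.

```python
QK8_0 = 32  # values per Q8_0 block (ggml convention)


def dequant_block(scale, qs):
    """Dequantize one Q8_0 block: value = scale * int8."""
    return [scale * q for q in qs]


def gemm_n_per_token(rows, xs):
    """Baseline: every token re-walks every weight block (weights read N times)."""
    out = []
    for x in xs:
        y = []
        for blocks in rows:
            acc = 0.0
            for b, (scale, qs) in enumerate(blocks):
                w = dequant_block(scale, qs)
                acc += sum(wi * xi for wi, xi in zip(w, x[b * QK8_0:(b + 1) * QK8_0]))
            y.append(acc)
        out.append(y)
    return out


def gemm_n_weight_reuse(rows, xs):
    """Weight reuse: each block is dequantized once and applied to all N tokens."""
    out = [[0.0] * len(rows) for _ in xs]
    for r, blocks in enumerate(rows):
        for b, (scale, qs) in enumerate(blocks):
            w = dequant_block(scale, qs)  # read once ...
            lo = b * QK8_0
            for t, x in enumerate(xs):    # ... reused for every token
                out[t][r] += sum(wi * xi for wi, xi in zip(w, x[lo:lo + QK8_0]))
    return out
```

Both orders add the per-block partial sums in the same sequence, so in this Python model they agree bit-for-bit; a real tiled kernel that reorders accumulation would need an order-preserving design to keep the bit-exact oracle tests meaningful.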
