Prefill optimization attack plan: Gemma 4 on RTX 4070 Ti (12 GB, Ada / sm_89) #247

@pekkah

Status (2026-06-14) — verified against current code

Tracing the actual prefill dispatch (not just the kernel wrappers) shows the headline Tier-1 item #2 is already done:

Still genuinely open (re-verified):

Net: lead with #1 (PLE), then #5/#6; #2 is done.


Goal

Improve prefill (TTFT) throughput for Gemma 4 models on a 12 GB Ada GPU (RTX 4070 Ti, sm_89). This issue captures a ranked attack plan based on profiling the current prefill path and surveying the external state of the art.

Context: what's already solved (do not re-chase)

The prefill path is already mature. The following are implemented and should not be treated as opportunities:

  • Batched all-GPU trunk prefill — CudaForwardPass.PrefillBatchedTrunk (CudaForwardPass.cs:2585)
  • Chunked prefill, 4096-token windows (CudaForwardPass.cs:200)
  • FlashAttention tensor-core Tc2 (multi-warp d-split, +27–40%) — CudaBackend.cs:4234
  • int8 tensor-core MMQ + SoA repack (Q8_0/Q4_K/Q4_0) — CudaBackend.cs:2308
  • fp16 compute-bound GEMM (dequant→fp16→cuBLAS) — CudaBackend.cs:2213
  • KV narrowing (bf16 / Q8_0): #179, "Quantized (fp16/q8_0) KV cache for the dense CUDA path → unlock long context (Gemma 4 12B first)"
  • Gemma SWA, dual-RoPE, QK-norm, softcap, KV-share, k_eq_v, per-layer head_dim
  • CUDA graphs — decode only (CudaForwardPass.cs:1549, 1858)

FP8 reality check (decision: skip for LLM prefill)

On Ada, fp8 and int8 tensor cores run at the same peak rate (~2× the fp16 rate with fp32 accumulation). Since int8 MMQ already exists, fp8 gives no prefill speed win on this card. Its only advantage is better accuracy on activation outliers at equal throughput. Gemma 4 ships native QAT int4/Q4_0 weights, so the int8 MMQ path is already the right tool. Native FP4/NVFP4 tensor-core acceleration is Blackwell-only. FlashAttention-3 (Hopper wgmma/TMA) does not apply to Ada either, so Tc2 is the correct ceiling here.


Attack plan (ranked by ROI)

0. Measure first

  • Enable s_prefillProfile and capture the [prefill-profile] breakdown (embed / ple / attn / matmul, CudaForwardPass.cs:2635) on the target Gemma 4 variant at a representative prompt length. Confirm the PLE and SWA-attention shares before writing kernels.
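An Amdahl-style bound per phase can turn the profile into priorities. The sketch below is minimal and assumes the four [prefill-profile] phases add up to wall time, which matches the serialized flow described under item 1. The timings dict is a placeholder, not a measurement. Substitute the real profile output.

```python
def phase_shares(timings_ms):
    """Fraction of total prefill wall time spent in each profiled phase."""
    total = sum(timings_ms.values())
    return {phase: t / total for phase, t in timings_ms.items()}


def max_speedup(timings_ms, phase, factor=float("inf")):
    """Amdahl bound on end-to-end prefill speedup if only `phase` becomes
    `factor` times faster; factor=inf models removing or fully hiding it."""
    total = sum(timings_ms.values())
    t = timings_ms[phase]
    remaining = 0.0 if factor == float("inf") else t / factor
    return total / (total - t + remaining)


# Placeholder numbers (NOT measurements); keys mirror the profile's phases.
timings = {"embed": 5.0, "ple": 30.0, "attn": 25.0, "matmul": 40.0}
```

Take the ~30% PLE share quoted in the code note. Fully hiding PLE then caps the end-to-end gain at 1/(1 - 0.3) ≈ 1.43×, and a 2× faster attention phase alone would be worth less than that.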

Tier 1 — Gemma-4-specific, codebase-confirmed bottlenecks

1. PLE pre-pass is CPU-bound (~30% of batched prefill).
Per the in-code note at CudaForwardPass.cs:2693: "this serial gather+dequant of N×stackedDim Q8_0 elements was ~30% of the whole batched prefill (#141 profiling)." BuildPerLayerProjectionsBatched (CudaForwardPass.cs:2682) runs a CPU Parallel.For dequant, then a full host→device upload, then the layer loop. The three steps are fully serialized (tPle is timed as a distinct phase). PLE is unique to Gemma 4, so generic LLM prefill optimizations don't cover it.
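An idealized timing model compares the serialized flow above with double-buffered overlap. This is a sketch under stated assumptions:

- "prep" is one chunk's PLE dequant plus upload.
- "compute" is that chunk's layer loop.
- The copy stream and compute stream are assumed not to slow each other down.
- There are exactly two PLE buffers.

The per-chunk costs are hypothetical inputs.

```python
def serial_ms(prep, compute):
    """Current order: each chunk's PLE prep finishes before that chunk's
    layers start, and nothing overlaps."""
    return sum(prep) + sum(compute)


def double_buffered_ms(prep, compute):
    """Two PLE buffers: chunk k+1's prep runs on a copy stream while chunk
    k's layers run on the compute stream. Buffer k % 2 can be refilled only
    after chunk k-2's layers have consumed it."""
    n = len(prep)
    prep_end = [0.0] * n
    comp_end = [0.0] * n
    for k in range(n):
        buf_free = comp_end[k - 2] if k >= 2 else 0.0
        prev_prep = prep_end[k - 1] if k >= 1 else 0.0
        prep_end[k] = max(prev_prep, buf_free) + prep[k]
        prev_comp = comp_end[k - 1] if k >= 1 else 0.0
        # Layers wait for their own PLE rows (the cudaStreamWaitEvent edge).
        comp_end[k] = max(prev_comp, prep_end[k]) + compute[k]
    return comp_end[-1]
```

Take three chunks with prep 3 and compute 7: serial is 30 and overlapped is 24, and swapping the two costs gives the same 24. In this model a prompt that fits in a single 4096-token chunk gets no overlap gain at all. For short prompts the GPU-side dequant, not the double buffering, is what removes the cost.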

  • GPU-side PLE gather+dequant — mirror EmbedLookupQ8_0Batched (CudaForwardPass.cs:2607) for the PLE table; index by token id on-device. Removes CPU dequant and the N×stackedDim upload. (E2B/E4B: keep table resident; 12B ~4.2 GB: stream rows / pinned-host UVA gather.)
  • Overlap — double-buffer: dequant chunk k+1's PLE rows while the GPU runs chunk k's layers (separate copy stream + cudaStreamWaitEvent, primitive already at CuBlasInterop.cs:151).
  • Cache dequantized PLE rows by token id (LRU in VRAM) — pure function of token id; big win on repeated tokens / shared prefixes.
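The gather+dequant and token-id cache bullets above can be modeled on the host as a reference for what a GPU kernel would compute. This is a minimal sketch with several assumptions:

- The PLE table uses the GGML-style Q8_0 layout: blocks of 32 int8 quants sharing one fp16 scale, value = scale × quant.
- Rows are stored contiguously by token id.
- The names `PleRowCache` and `gather_ple` are hypothetical.

```python
import struct
from collections import OrderedDict

QK8_0 = 32                # elements per Q8_0 block (GGML-style layout, assumed)
BLOCK_BYTES = 2 + QK8_0   # fp16 scale followed by 32 int8 quants


def dequant_q8_0_row(table, token_id, stacked_dim):
    """Dequantize the PLE row (stacked_dim elements) for one token id."""
    n_blocks = stacked_dim // QK8_0
    base = token_id * n_blocks * BLOCK_BYTES
    out = []
    for b in range(n_blocks):
        off = base + b * BLOCK_BYTES
        (scale,) = struct.unpack_from("<e", table, off)         # fp16 scale
        quants = struct.unpack_from(f"<{QK8_0}b", table, off + 2)
        out.extend(scale * q for q in quants)
    return out


class PleRowCache:
    """LRU of dequantized rows keyed by token id. A row is a pure function
    of the id, so a hit skips the dequant entirely."""

    def __init__(self, table, stacked_dim, capacity):
        self.table, self.stacked_dim, self.capacity = table, stacked_dim, capacity
        self.rows = OrderedDict()
        self.hits = self.misses = 0

    def get(self, token_id):
        row = self.rows.get(token_id)
        if row is not None:
            self.rows.move_to_end(token_id)    # mark as most recently used
            self.hits += 1
            return row
        self.misses += 1
        row = dequant_q8_0_row(self.table, token_id, self.stacked_dim)
        self.rows[token_id] = row
        if len(self.rows) > self.capacity:
            self.rows.popitem(last=False)      # evict least recently used
        return row


def gather_ple(cache, token_ids):
    """Batched gather: one dequantized row per prompt token (N x stackedDim).
    On the GPU this loop would become a kernel indexed by token id, in the
    spirit of EmbedLookupQ8_0Batched; here it is a plain loop."""
    return [cache.get(t) for t in token_ids]
```

Repeated tokens and shared prefixes show up as cache hits. The only per-token host→device traffic left is the token id itself, not the N×stackedDim dequantized rows.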

2. SWA layers have no tensor-core flash path.
Gemma 3/4 use a 5:1 local:global ratio, so 5/6 of layers are sliding-window. Global layers use FlashAttentionPrefillTc2 (CudaBackend.cs:4234), but SWA layers route to the scalar AttentionSwaBatched (CudaBackend.cs:4088). As a result, only 1/6 of attention layers run on tensor cores.

  • Windowed FlashAttention-Tc2 for SWA layers. The bounded, contiguous key range (window 512–1024) makes the tiling simpler than in the global case. This is likely the biggest remaining attention win.
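The main new bookkeeping a windowed Tc2 kernel needs, compared with the global case, is the loop bound and per-tile mask decision. It can be modeled in a few lines. This is a sketch, not the kernel. The window convention is an assumption: query i sees the `window` most recent keys, itself included. Check it against AttentionSwaBatched before relying on it. Tile sizes are hypothetical.

```python
def visible(i, j, window):
    """Sliding-window causal mask (assumed convention): query i sees keys j
    with i - window < j <= i."""
    return i - window < j <= i


def key_tile_range(q0, bq, bk, window):
    """Inclusive key-tile indices a query tile [q0, q0 + bq) must visit.
    The count is bounded by the window, not by the prompt length."""
    lo = max(0, q0 - window + 1)   # oldest key visible to the first query
    hi = q0 + bq - 1               # newest key visible to the last query
    return lo // bk, hi // bk


def tile_kind(q0, bq, k0, bk, window):
    """Classify a (query tile, key tile) pair:
    'full'    every pair visible -> run the MMA tile with no masking
    'empty'   no pair visible    -> skip the tile
    'partial' otherwise          -> apply the element-wise mask"""
    q_lo, q_hi = q0, q0 + bq - 1
    k_lo, k_hi = k0, k0 + bk - 1
    if k_hi <= q_lo and k_lo > q_hi - window:
        return "full"
    if k_lo > q_hi or k_hi <= q_lo - window:
        return "empty"
    return "partial"
```

For example, `key_tile_range(4096, 64, 64, 1024)` gives key tiles 48 through 64. That is 17 key tiles per 64-row query tile, and the count stays the same for any later query tile however long the prompt grows.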

Tier 2 — architecture-level (evaluate)

3. MInference-style dynamic sparse attention (global layers only).
MInference reports up to 95% attention FLOP reduction and ~10× faster prefill (A-shape / Vertical-Slash / Block-Sparse patterns). Caveat: it only helps full-context global layers, which are 1/6 of layers and several of which are already MQA/KV-shared/k_eq_v. SWA already bounds the rest. Only worth it for long prompts (≥16–32k tokens).

  • Prototype only if long-context is a target workload.
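A rough count of attention score work shows why global-layer sparsity only pays off for long prompts. This is a back-of-envelope sketch with three assumptions:

- Cost is proportional to the number of (query, key) pairs.
- Global and SWA layers have equal head count and head_dim, which Gemma 4's per-layer head_dim does not strictly satisfy.
- The window is 1024.

```python
def causal_pairs(n):
    """(query, key) score pairs in full causal attention over n tokens."""
    return n * (n + 1) // 2


def swa_pairs(n, window):
    """Pairs when each query sees at most `window` keys (itself included)."""
    if n <= window:
        return causal_pairs(n)
    return causal_pairs(window) + (n - window) * window


def global_attention_share(n, window, local_per_global=5):
    """Share of attention score work in global layers for a 5:1 local:global
    stack, under the equal-shape assumption above."""
    g = causal_pairs(n)
    s = local_per_global * swa_pairs(n, window)
    return g / (g + s)
```

Under these assumptions, global layers hold about 31% of attention score work at 4k tokens and about 76% at 32k. Any sparsity saving applies only to that share.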

4. Cascade / shared-prefix attention (multi-user).
Compute shared-prefix attention once, merge per-request suffix via online-softmax (FlashInfer cascade). Natural extension of TruncateTo prefix reuse to ContinuousBatchingEngine.

  • Server/batched mode only.
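Cascade attention depends on one identity: attention over the shared prefix and over the per-request suffix can be computed separately and merged exactly with online-softmax rescaling. The sketch below shows that merge for a single query. It is plain Python and omits the 1/sqrt(d) score scale. Function names are illustrative, not FlashInfer's API.

```python
import math


def partial_attention(q, keys, values):
    """Attention of one query over a subset of keys, returned as the
    online-softmax state (m, l, o): max score, sum of exp(score - m), and
    the output normalized over this subset only."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    l = sum(w)
    o = [sum(wi * v[c] for wi, v in zip(w, values)) / l
         for c in range(len(values[0]))]
    return m, l, o


def merge_states(a, b):
    """Combine states over disjoint key sets (e.g. shared prefix and
    per-request suffix) into the state over their union."""
    (m1, l1, o1), (m2, l2, o2) = a, b
    m = max(m1, m2)
    s1 = l1 * math.exp(m1 - m)   # rescale both partial sums to the common max
    s2 = l2 * math.exp(m2 - m)
    l = s1 + s2
    o = [(s1 * x + s2 * y) / l for x, y in zip(o1, o2)]
    return m, l, o
```

In a cascade setup, the prefix states for all requests sharing a prefix come from one pass over the shared KV. Each request then merges in its own suffix state. The merge is associative, so it also works when the prefix is itself split across chunks.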

Tier 3 — kernel & scheduling

5. CUDA graphs for prefill chunks.
Graphs are decode-only today. A fixed chunk size yields a static kernel sequence that can be captured. Replaying it cuts per-launch overhead across the dozens of small launches in GpuLayerBatchedTrunk, repeated for all L layers. This helps short/medium prompts. Gemma4CudaGraphParityTests already exists for validation.

  • Capture+replay for fixed-size prefill chunks.
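Host-side control flow for capture+replay over fixed-size chunks can be kept small. This sketch uses a hypothetical cache keyed on chunk length. `capture` stands in for cudaStreamBeginCapture, the kernel launches, cudaStreamEndCapture and cudaGraphInstantiate. `run` stands in for cudaGraphLaunch. Stream capture records work without executing it, so every chunk, including the first, is a launch of the instantiated graph.

```python
class GraphCache:
    """Capture-once / replay-many cache keyed on chunk length (hypothetical)."""

    def __init__(self, capture):
        self._capture = capture
        self._graphs = {}
        self.captures = 0
        self.replays = 0

    def run(self, chunk_len):
        graph = self._graphs.get(chunk_len)
        if graph is None:
            # Captured kernel arguments and buffer addresses are frozen, so
            # per-chunk values (e.g. the position offset) must be read from
            # device memory or patched before launch.
            graph = self._graphs[chunk_len] = self._capture(chunk_len)
            self.captures += 1
        self.replays += 1
        return graph


def plan_chunks(n_tokens, chunk=4096):
    """Split a prompt into full-size chunks plus one shorter tail."""
    full, tail = divmod(n_tokens, chunk)
    return [chunk] * full + ([tail] if tail else [])


def prefill(n_tokens, cache, chunk=4096):
    """Replay the graph for full chunks and launch the odd-sized tail eagerly.
    Returns the number of eagerly launched chunks."""
    eager = 0
    for length in plan_chunks(n_tokens, chunk):
        if length == chunk:
            cache.run(length)
        else:
            eager += 1   # ordinary per-kernel launches, as today
    return eager
```

Capturing only the fixed chunk size keeps a single graph, but prompts shorter than one chunk stay on the eager path. To reach the short/medium prompts this item targets, one option is to capture a few more lengths, which GraphCache already keys on. That trades capture time and graph memory for coverage.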

6. Keep hot weights resident in VRAM (skip per-chunk re-dequant).
MatMulBatchedGemm dequantizes to an fp16 temp buffer per batch, so the same weights are dequantized again for every chunk of one prompt. This would be the GPU analogue of the CPU _dequantWeightCache (#189), budgeted like SHARPI_PREFILL_DEQUANT_MB. E4B on 12 GB has headroom.

  • VRAM-resident fp16/int8 cache for hottest projections.
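One simple way to pick which projections stay resident under a VRAM budget is sketched below. It assumes:

- Resident fp16 copies are built once and kept across prompts.
- Each resident copy saves one dequant pass per chunk.
- Projection names, shapes and dequant timings are hypothetical inputs.

The greedy ranking by time saved per MB is a heuristic, not an optimal knapsack solution.

```python
from dataclasses import dataclass


@dataclass
class Projection:
    name: str           # hypothetical, e.g. "blk.3.ffn_up"
    rows: int
    cols: int
    dequant_ms: float   # measured cost of one dequant -> fp16 pass

    @property
    def fp16_mb(self):
        return self.rows * self.cols * 2 / (1024 * 1024)


def pick_resident(projs, budget_mb, chunks_per_prompt):
    """Greedy: keep the projections that save the most dequant time per MB
    of VRAM until the budget is spent. Returns (names, MB used, ms saved per
    prompt), assuming each resident copy skips one dequant per chunk."""
    ranked = sorted(projs, key=lambda p: p.dequant_ms / p.fp16_mb, reverse=True)
    chosen, used = [], 0.0
    for p in ranked:
        if used + p.fp16_mb <= budget_mb:
            chosen.append(p.name)
            used += p.fp16_mb
    saved = chunks_per_prompt * sum(p.dequant_ms for p in projs if p.name in chosen)
    return chosen, used, saved
```

A per-MB greedy can leave budget unused when one large, valuable projection does not fit. With only a few dozen projections per model, an exact small knapsack is also cheap if that matters.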

Suggested order

  1. Measure (0)
  2. PLE overlap + GPU dequant (1) — most self-contained, Gemma-4-unique
  3. Windowed FlashAttention-Tc2 for SWA (2)
  4. Prefill CUDA graphs (5)
  5. Long-context only: global-layer sparsity (3)

References

Working branch: claude/gemma4-prefill-optimization-h7syay
