PLE GPU gather+dequant: still a CPU Parallel.For dequant plus a host→device upload (CudaForwardPass.cs:2682-2700); the "~30% of batched prefill" note in the code still stands. This is the top lever, and it is unique to Gemma 4.
Goal
Improve prefill (TTFT) throughput for Gemma 4 models on a 12 GB Ada GPU (RTX 4070 Ti, sm_89). This issue captures a ranked attack plan based on profiling the current prefill path and surveying the external state of the art.
Context: what's already solved (do not re-chase)
The prefill path is already mature. The following are implemented and should not be treated as opportunities:
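CudaForwardPass.PrefillBatchedTrunk (CudaForwardPass.cs:2585)
CudaForwardPass.cs:200
FlashAttentionPrefillTc2 for global layers (CudaBackend.cs:4234)
CudaBackend.cs:2308
CudaBackend.cs:2213
CUDA graphs, decode only (CudaForwardPass.cs:1549, 1858)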
FP8 reality check (decision: skip for LLM prefill)
On Ada, fp8 and int8 tensor cores run at the same peak rate (~2× fp16+fp32-accum). Since int8 MMQ already exists, fp8 gives no prefill speed win on this card — its only edge is activation-outlier accuracy at equal throughput. Gemma 4 ships native QAT int4/Q4_0, so the int8 MMQ path is already the right tool. FP4/NVFP4 native tensor-core acceleration is Blackwell-only. FlashAttention-3 (Hopper wgmma/TMA) does not apply to Ada — Tc2 is the correct ceiling here.
Attack plan (ranked by ROI)
0. Measure first
Enable s_prefillProfile and capture the [prefill-profile] breakdown (embed / ple / attn / matmul, CudaForwardPass.cs:2635) on the target Gemma 4 variant + a representative prompt length. Confirm PLE and SWA-attention shares before writing kernels.
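Before ranking kernels, the measured shares bound what each fix can buy (Amdahl's law). A minimal Python sketch; the ~30% PLE share is the in-code figure cited in item 1, and every other number is illustrative:

```python
import math

def amdahl_speedup(share: float, phase_speedup: float) -> float:
    """End-to-end prefill speedup when a phase taking `share` of wall time
    becomes `phase_speedup` times faster and everything else is unchanged."""
    if not 0.0 <= share <= 1.0:
        raise ValueError("share must be in [0, 1]")
    if phase_speedup <= 0:
        raise ValueError("phase_speedup must be positive")
    return 1.0 / ((1.0 - share) + share / phase_speedup)

def ceiling(share: float) -> float:
    """Upper bound if the phase disappears entirely (infinite phase speedup)."""
    return amdahl_speedup(share, math.inf)
```

With a ~30% PLE share, removing the PLE phase entirely caps the end-to-end gain at about 1.43×, and a 10× faster PLE phase gives about 1.37×. Plugging the real [prefill-profile] shares into the same two functions shows whether PLE or SWA attention should go first on a given prompt length.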
Tier 1: Gemma-4-specific, codebase-confirmed bottlenecks
1. PLE pre-pass is CPU-bound (~30% of batched prefill).
Per the in-code note at CudaForwardPass.cs:2693: "this serial gather+dequant of N×stackedDim Q8_0 elements was ~30% of the whole batched prefill (#141 profiling)." BuildPerLayerProjectionsBatched (CudaForwardPass.cs:2682) does a CPU Parallel.For dequant → full host→device upload → then the layer loop (fully serialized; tPle is a distinct phase). PLE is unique to Gemma 4, so generic LLM prefill work doesn't cover it.
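A GPU gather+dequant should reproduce the current CPU result up to float rounding, so a small reference is handy for parity tests. The sketch below assumes the PLE table uses ggml's Q8_0 block layout (32 int8 quants sharing one fp16 scale, 34 bytes per block) and is stored row-major with one contiguous row of stackedDim values per token id; all function names are hypothetical, not the codebase's. A CUDA port could map one thread block per output token and one thread per element, reading each block's scale once.

```python
import struct

QK8_0 = 32                       # quants per Q8_0 block (ggml layout)
BLOCK_BYTES = 2 + QK8_0          # one fp16 scale + 32 int8 quants

def q8_0_row_bytes(dim):
    """Bytes per table row of `dim` values (dim must be a multiple of 32)."""
    assert dim % QK8_0 == 0
    return dim // QK8_0 * BLOCK_BYTES

def quantize_row_q8_0(values):
    """Float row -> Q8_0 bytes; only used here to build test tables."""
    assert len(values) % QK8_0 == 0
    out = bytearray()
    for b in range(0, len(values), QK8_0):
        block = values[b:b + QK8_0]
        d = max(abs(x) for x in block) / 127.0
        inv = 1.0 / d if d else 0.0
        qs = [max(-127, min(127, round(x * inv))) for x in block]
        out += struct.pack("<e32b", d, *qs)      # fp16 scale, then 32 int8
    return bytes(out)

def dequant_row_q8_0(row):
    """Q8_0 bytes -> floats: value = fp16 scale * int8 quant."""
    vals = []
    for off in range(0, len(row), BLOCK_BYTES):
        d, *qs = struct.unpack_from("<e32b", row, off)
        vals.extend(d * q for q in qs)
    return vals

def gather_dequant_q8_0(table, token_ids, row_bytes):
    """Output row t = dequant(table row token_ids[t]); what a GPU kernel must reproduce."""
    out = []
    for tok in token_ids:
        start = tok * row_bytes
        out.append(dequant_row_q8_0(table[start:start + row_bytes]))
    return out
```

Indexing by token id on-device is what removes the N×stackedDim upload: only the N token ids cross the bus, and the table either stays resident or is read through pinned host memory.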
GPU-side PLE gather+dequant — mirror EmbedLookupQ8_0Batched (CudaForwardPass.cs:2607) for the PLE table; index by token id on-device. Removes CPU dequant and the N×stackedDim upload. (E2B/E4B: keep table resident; 12B ~4.2 GB: stream rows / pinned-host UVA gather.)
Overlap — double-buffer: dequant chunk k+1's PLE rows while the GPU runs chunk k's layers (separate copy stream + cudaStreamWaitEvent, primitive already at CuBlasInterop.cs:151).
Cache dequantized PLE rows by token id (LRU in VRAM). Each row is a pure function of its token id, so repeated tokens and shared prefixes skip the dequant entirely.
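The overlap and caching items can be sanity-checked on the host before touching streams. The Python sketch below (all names hypothetical) models per-chunk timings under today's serialized shape versus a double buffer, where chunk k+1's PLE rows are produced while chunk k's layers run but a buffer cannot be refilled until the layers reading it have finished; it also shows an LRU keyed by token id.

```python
from collections import OrderedDict

def serial_makespan(ple_ms, layers_ms):
    """Today's shape: each chunk's PLE pre-pass completes before its layer loop."""
    return sum(ple_ms) + sum(layers_ms)

def overlapped_makespan(ple_ms, layers_ms, buffers=2):
    """PLE runs on a copy stream; layers of chunk k wait on chunk k's PLE event.
    PLE for chunk k reuses the buffer of chunk k - buffers, so it waits for those layers."""
    ple_end, layer_end = [], []
    for k, (ple, layers) in enumerate(zip(ple_ms, layers_ms)):
        start = ple_end[-1] if ple_end else 0.0          # copy stream is sequential
        if k >= buffers:
            start = max(start, layer_end[k - buffers])   # buffer must be free again
        ple_end.append(start + ple)
        prev = layer_end[-1] if layer_end else 0.0
        layer_end.append(max(prev, ple_end[-1]) + layers)
    return layer_end[-1] if layer_end else 0.0

class PleRowCache:
    """LRU of dequantized PLE rows keyed by token id."""

    def __init__(self, capacity, load_row):
        self.capacity = capacity
        self.load_row = load_row        # e.g. a GPU gather+dequant of one row
        self.rows = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get(self, token_id):
        if token_id in self.rows:
            self.rows.move_to_end(token_id)   # mark as most recently used
            self.hits += 1
            return self.rows[token_id]
        self.misses += 1
        row = self.load_row(token_id)
        self.rows[token_id] = row
        if len(self.rows) > self.capacity:
            self.rows.popitem(last=False)     # evict least recently used
        return row
```

With 3 ms of PLE and 7 ms of layers per chunk over three chunks, the serialized shape takes 30 ms and the double-buffered one 24 ms: once layers dominate, only the first chunk's PLE stays exposed.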
2. SWA layers have no tensor-core flash path.
Gemma 3/4 use a 5:1 local:global ratio, so 5/6 of layers are sliding-window. Global layers use FlashAttentionPrefillTc2 (CudaBackend.cs:4234), but SWA layers route to the scalar AttentionSwaBatched path (CudaBackend.cs:4088), so only 1/6 of attention layers run on tensor cores.
Windowed FlashAttention-Tc2 for SWA layers. Bounded contiguous key range (window 512–1024) makes the tiling simpler than the global case. Likely the biggest remaining attention win.
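The window simplifies tiling because each query tile touches a fixed-width band of key tiles regardless of its position. The Python reference below is a sketch assuming key j is visible to query i when 0 ≤ i − j < window (to be checked against the model's sliding-window setting); it runs the same online softmax a Tc2-style kernel uses but visits only in-band key tiles, and includes a naive oracle for parity tests. Names and tile sizes are illustrative.

```python
import math

def swa_key_tiles(q_start, q_end, window, bk):
    """Key tiles that query rows [q_start, q_end) can see under causal SWA.
    Assumed mask: key j is visible to query i iff 0 <= i - j < window."""
    k_lo = max(0, q_start - window + 1)   # oldest key the first query row can see
    return range(k_lo // bk, (q_end - 1) // bk + 1)

def windowed_flash_attention(q, k, v, window, bq=64, bk=64):
    """Tiled causal SWA with an online softmax, visiting only in-band key tiles."""
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = [None] * n
    for q0 in range(0, n, bq):
        q1 = min(q0 + bq, n)
        rows = q1 - q0
        m = [-math.inf] * rows                            # running max score per row
        l = [0.0] * rows                                  # running denominator per row
        acc = [[0.0] * len(v[0]) for _ in range(rows)]    # unnormalized output
        for t in swa_key_tiles(q0, q1, window, bk):
            for j in range(t * bk, min((t + 1) * bk, n)):
                for r in range(rows):
                    i = q0 + r
                    if j > i or i - j >= window:          # causal + window mask in the band
                        continue
                    s = scale * sum(a * b for a, b in zip(q[i], k[j]))
                    m_new = max(m[r], s)
                    corr = math.exp(m[r] - m_new)         # rescale old state (0.0 at start)
                    p = math.exp(s - m_new)
                    l[r] = l[r] * corr + p
                    acc[r] = [x * corr + p * y for x, y in zip(acc[r], v[j])]
                    m[r] = m_new
        for r in range(rows):
            out[q0 + r] = [x / l[r] for x in acc[r]]
    return out

def naive_swa(q, k, v, window):
    """Oracle: dense softmax over the same causal window, one query at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i in range(len(q)):
        js = [j for j in range(i + 1) if i - j < window]
        s = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in js]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        z = sum(w)
        out.append([sum(wj * v[j][c] for wj, j in zip(w, js)) / z
                    for c in range(len(v[0]))])
    return out
```

For a 64-row query tile at position 1000 with a 512-token window and 64-key tiles, swa_key_tiles visits 10 key tiles where a full causal pass would visit 17, and that count stays constant as the prompt grows.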
Tier 2: architecture-level (evaluate)
3. MInference-style dynamic sparse attention (global layers only).
Up to 95% attention FLOP reduction / ~10× prefill (A-shape / Vertical-Slash / Block-Sparse). Caveat: only helps full-context global layers (1/6, several already MQA/KV-shared/k_eq_v); SWA already bounds the rest. Only worth it for long prompts (≥16–32k).
Prototype only if long-context is a target workload.
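The long-prompt caveat follows from simple counting. Taking only the A-shape pattern (a few initial "sink" tokens plus a local band), the Python sketch below counts how many causal query/key pairs survive; the sink and band sizes are hypothetical, and it ignores the online index-building cost that Vertical-Slash and Block-Sparse add.

```python
def causal_pairs(n):
    """(query i, key j) pairs with j <= i for an n-token prompt."""
    return n * (n + 1) // 2

def a_shape_pairs(n, sink, local):
    """Causal pairs kept when key j is one of the first `sink` tokens
    or lies within the last `local` positions up to and including i."""
    kept = 0
    for i in range(n):
        local_lo = max(0, i - local + 1)
        local_cnt = i - local_lo + 1
        sink_cnt = min(sink, local_lo)   # sink keys not already inside the band
        kept += local_cnt + sink_cnt
    return kept

def kept_fraction(n, sink, local):
    return a_shape_pairs(n, sink, local) / causal_pairs(n)
```

With 64 sink tokens and a 1024-token band, about 78% of causal pairs survive at 2k tokens against about 6.5% at 32k, which is why the skip only pays off deep into long context.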
4. Cascade / shared-prefix attention (multi-user).
Compute shared-prefix attention once, merge per-request suffix via online-softmax (FlashInfer cascade). Natural extension of TruncateTo prefix reuse to ContinuousBatchingEngine.
Server/batched mode only.
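The cascade merge relies on softmax attention over a union of key sets being recoverable from per-subset (output, log-sum-exp) pairs. A minimal Python sketch for a single query, ignoring masking and batching; the names are hypothetical and not FlashInfer's API:

```python
import math

def partial_attention(q, keys, values):
    """Attention of one query over a key subset -> (normalized output, log-sum-exp)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    out = [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]
    return out, m + math.log(z)

def merge_states(a, b):
    """Combine two partial states as if one softmax had run over both key sets."""
    (o_a, lse_a), (o_b, lse_b) = a, b
    # numerically stable log(exp(lse_a) + exp(lse_b))
    lse = max(lse_a, lse_b) + math.log1p(math.exp(-abs(lse_a - lse_b)))
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [w_a * x + w_b * y for x, y in zip(o_a, o_b)], lse
```

In cascade form, the prefix state is computed once for every request sharing that prefix and merged with each request's own suffix state; the merge reproduces a single softmax over prefix plus suffix keys up to rounding.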
Tier 3: kernel & scheduling
5. CUDA graphs for prefill chunks.
Graphs are decode-only today. A fixed chunk size → static kernel sequence → capturable; cuts per-launch overhead across the dozens of small launches in GpuLayerBatchedTrunk × L. Helps short/medium prompts. Gemma4CudaGraphParityTests already exists for validation.
Capture+replay for fixed-size prefill chunks.
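Capture only works if every chunk runs the same kernel sequence with the same launch shapes, which suggests a fixed chunk size with a padded tail. A Python sketch of the planning step; the names and launch counts are hypothetical:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Chunk:
    start: int    # first prompt position in this chunk
    valid: int    # real tokens in the chunk
    padded: int   # rows the captured graph processes (always chunk_size)

def plan_chunks(prompt_len, chunk_size):
    """Split a prompt into fixed-size chunks so each chunk replays one captured graph.
    The tail chunk is padded; its padded rows must be masked and their KV writes dropped."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [Chunk(s, min(chunk_size, prompt_len - s), chunk_size)
            for s in range(0, prompt_len, chunk_size)]

def launches_saved(prompt_len, chunk_size, launches_per_layer, layers):
    """Individual kernel launches replaced by one graph launch per chunk."""
    n_chunks = len(plan_chunks(prompt_len, chunk_size))
    return n_chunks * (launches_per_layer * layers - 1)
```

Because captured kernel arguments are fixed at capture time, per-chunk values such as the start position would have to come from device buffers updated before each replay (or be patched via cudaGraphExecKernelNodeSetParams), and attention grids that depend on the growing KV length need a fixed upper bound. With an illustrative 12 launches per layer over 35 layers, a 1000-token prompt in 256-token chunks turns 1680 individual launches into 4 graph replays.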
6. Keep hot weights resident in VRAM (skip per-chunk re-dequant). MatMulBatchedGemm dequants to an fp16 temp per batch; across chunks of one prompt the same weights are re-dequanted. GPU analogue of the CPU _dequantWeightCache (#189), budgeted like SHARPI_PREFILL_DEQUANT_MB. E4B on 12 GB has headroom.
VRAM-resident fp16/int8 cache for hottest projections.
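A budgeted cache needs a rule for which projections to keep. The Python sketch below assumes per-weight dequant timings have been measured, charges each weight its fp16 size (rows × cols × 2 bytes), and greedily keeps the best re-dequant time saved per byte under an MB budget in the spirit of SHARPI_PREFILL_DEQUANT_MB. The greedy pick is a knapsack approximation, and the projection names are hypothetical.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Projection:
    name: str            # hypothetical weight name
    rows: int
    cols: int
    dequant_ms: float    # measured cost to dequant this weight once

    @property
    def fp16_bytes(self):
        return self.rows * self.cols * 2

def pick_resident(projs, budget_mb, chunks_per_prompt):
    """Keep weights with the most re-dequant time saved per byte until the budget is spent.
    A resident weight is dequantized once per prompt instead of once per chunk."""
    budget = budget_mb * 1024 * 1024

    def saved_per_byte(p):
        return p.dequant_ms * (chunks_per_prompt - 1) / p.fp16_bytes

    chosen, used = [], 0
    for p in sorted(projs, key=saved_per_byte, reverse=True):
        if saved_per_byte(p) <= 0:
            continue                      # single-chunk prompts gain nothing
        if used + p.fp16_bytes <= budget:
            chosen.append(p.name)
            used += p.fp16_bytes
    return chosen, used
```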
Suggested order
Measure (0)
PLE overlap + GPU dequant (1) — most self-contained, Gemma-4-unique
References
Working branch: claude/gemma4-prefill-optimization-h7syay