Goal
Generalize the Gemma 4 decode CUDA-graph capture/replay (#136) to the non-Gemma per-token CudaForwardPass.Forward, so dense models (Qwen3-8B Q4_K etc.) get the same launch-overhead reduction. Split out of the #156 umbrella as Work Item B.
Decode is HBM-bandwidth-bound (see #142 roofline notes), so the win is not kernel throughput: it is eliminating the ~1k host launches per token. Qwen3-8B's decode loop pays the same per-token launch overhead.
Background
ForwardGemma4 (CudaForwardPass.cs:~1082) is structured as: token-varying prefix (embed/PLE) → capturable device region (RunGemma4DeviceRegion, the layer loop + final norm/output) → logits download. The graph helpers it uses are already generic (nothing Gemma-specific in CudaBackend):
TryBeginGraphCapture / TryEndGraphCaptureAndInstantiate / LaunchGraphForPosition.
The non-Gemma Forward (:~890) inlines embed + layers + output + download in one method with no capture.
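Below is a minimal host-side mock of the split that ForwardGemma4 already uses, assuming the dense Forward adopts the same shape. It has three parts:

- a token-varying prefix that runs outside capture;
- a device region that is captured once and replayed every token;
- a direct-launch fallback on the TQ / KV-eviction bail.

Every name below is hypothetical, including MockBackend, DenseForward, layer_kernel and the "pos" slot. Nothing calls CUDA or mirrors the real CudaBackend signatures. The sketch also assumes per-token values reach the graph through a fixed device buffer. It does this because a captured graph reuses the kernel arguments it was recorded with unless its node parameters are updated. The real LaunchGraphForPosition may handle the position differently.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

Kernel = Callable[[Dict[str, float]], None]


@dataclass
class MockBackend:
    """Hypothetical stand-in for CudaBackend: counts host-side launches."""
    launches: int = 0
    capturing: bool = False
    recorded: List[Kernel] = field(default_factory=list)

    def launch(self, kernel: Kernel, state: Dict[str, float]) -> None:
        if self.capturing:
            self.recorded.append(kernel)  # captured, not executed
        else:
            self.launches += 1
            kernel(state)

    def begin_capture(self) -> None:
        self.capturing, self.recorded = True, []

    def end_capture(self) -> List[Kernel]:
        self.capturing = False
        return list(self.recorded)  # plays the role of the instantiated graph

    def replay(self, graph: List[Kernel], state: Dict[str, float]) -> None:
        self.launches += 1  # one host launch replays every recorded kernel
        for kernel in graph:
            kernel(state)


def layer_kernel(i: int) -> Kernel:
    def run(s: Dict[str, float]) -> None:
        # Toy "layer": reads the position from device state at execution time.
        s["x"] = s["x"] * 2 + s["pos"] + i
    return run


def output_kernel(s: Dict[str, float]) -> None:
    s["logits"] = s["x"] + 1  # toy final norm + output projection


class DenseForward:
    def __init__(self, backend: MockBackend, n_layers: int, use_graph: bool = True):
        self.b = backend
        self.n_layers = n_layers
        self.use_graph = use_graph
        self.tq_enabled = False    # mirrors the _tqEnabled bail
        self.kv_evicted_count = 0  # mirrors the _kvEvictedCount > 0 bail
        # Persistent "device buffers": a replayed graph must see the same storage.
        self.state: Dict[str, float] = {"x": 0.0, "pos": 0.0, "logits": 0.0}
        self.graph: Optional[List[Kernel]] = None

    def run_device_region(self) -> None:
        # Layer loop + final norm/output: the capturable part.
        for i in range(self.n_layers):
            self.b.launch(layer_kernel(i), self.state)
        self.b.launch(output_kernel, self.state)

    def try_run_via_graph(self) -> bool:
        # Fall back to direct launch whenever host-side state would change topology.
        if not self.use_graph or self.tq_enabled or self.kv_evicted_count > 0:
            return False
        if self.graph is None:
            self.b.begin_capture()
            self.run_device_region()  # recorded once, not executed
            self.graph = self.b.end_capture()
        self.b.replay(self.graph, self.state)
        return True

    def forward(self, token: int, position: int) -> float:
        self.state["x"] = float(token)       # token-varying prefix (embed), outside capture
        self.state["pos"] = float(position)  # per-token value written to a fixed buffer
        if not self.try_run_via_graph():
            self.run_device_region()
        return self.state["logits"]          # stands in for the logits download


direct = DenseForward(MockBackend(), n_layers=3, use_graph=False)
graphed = DenseForward(MockBackend(), n_layers=3)
for pos, tok in enumerate([5, 7, 11]):
    assert direct.forward(tok, pos) == graphed.forward(tok, pos)
print(direct.b.launches, graphed.b.launches)  # 12 3
```

The launch count is the point: the direct path issues n_layers + 1 host launches per token, while the graphed path issues one.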
Tasks
- Refactor Forward into: token-varying prefix (embed) → RunDeviceRegion(position) (layer loop + final norm + output projection) → logits download + sync.
- Keep the TQ state mutation (:~1064) and the SnapKV Q-capture (:~964) outside the captured region. Both mutate host/ring state, which breaks the static topology a graph requires.
- Wrap RunDeviceRegion with the generic capture helpers. Mirror TryRunGemma4DeviceRegionViaGraph (:~1304), reusing the _tqEnabled || _kvEvictedCount > 0 bail.
- Make sure nothing inside the captured region calls cudaMalloc. Note: the Q4_K matvec Q8_1 scratch currently grows on demand, so ensure it is sized before capture.
- Add a dense analogue of Gemma4CudaGraphParityTests; the Gemma 4 test already exists as a template.
- … (SHARPI_CUDA_GRAPH=0 to bisect); update README and close out in #156 (perf(cuda): generalize Gemma 4 batched prefill + flash attn + decode graphs to dense Q4_K models).
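The scratch-sizing task can be illustrated with a small mock. A buffer that grows only when it is too small will try to allocate the first time the captured region asks for it. It therefore has to be grown to the largest K in the decode path before capture begins.

The sketch rests on these assumptions:

- It assumes a ggml-style block_q8_1 layout: 32 int8 quants plus two fp16 scalars, or 36 bytes per 32 values.
- MockDevice, GrowableScratch, presize_for_capture and the K values are hypothetical.
- CaptureAllocError only models the failure; it does not reproduce CUDA's actual error.

```python
QK8_1 = 32             # values per Q8_1 block (assumed ggml-style layout)
BLOCK_Q8_1_BYTES = 36  # 32 x int8 quants + 2 x fp16 scalars (assumption)


class CaptureAllocError(RuntimeError):
    """Models the failure: an allocation issued while a stream is being captured."""


class MockDevice:
    def __init__(self) -> None:
        self.capturing = False
        self.alloc_calls = 0

    def malloc(self, nbytes: int) -> bytearray:
        if self.capturing:
            raise CaptureAllocError(f"malloc({nbytes}) inside captured region")
        self.alloc_calls += 1
        return bytearray(nbytes)


class GrowableScratch:
    """Hypothetical grow-on-demand scratch, standing in for the Q8_1 buffer."""

    def __init__(self, dev: MockDevice) -> None:
        self.dev = dev
        self.buf = bytearray()

    def ensure(self, nbytes: int) -> bytearray:
        if len(self.buf) < nbytes:  # grow only when too small
            self.buf = self.dev.malloc(nbytes)
        return self.buf


def q8_1_bytes(k: int) -> int:
    """Bytes needed to quantize a length-k activation vector to Q8_1."""
    if k % QK8_1:
        raise ValueError("k must be a multiple of 32")
    return (k // QK8_1) * BLOCK_Q8_1_BYTES


def capture_decode_region(dev: MockDevice, scratch: GrowableScratch, ks: list) -> None:
    """Pretend capture: every matvec asks the scratch for its Q8_1 input size."""
    dev.capturing = True
    try:
        for k in ks:
            scratch.ensure(q8_1_bytes(k))
    finally:
        dev.capturing = False


def presize_for_capture(scratch: GrowableScratch, ks: list) -> None:
    """Grow the scratch to the largest K in the decode path before capturing."""
    scratch.ensure(q8_1_bytes(max(ks)))


ks = [4096, 4096, 12288]  # hypothetical matvec input widths
dev = MockDevice()
scratch = GrowableScratch(dev)
try:
    capture_decode_region(dev, scratch, ks)
except CaptureAllocError as e:
    print("capture failed:", e)  # capture failed: malloc(4608) inside captured region
presize_for_capture(scratch, ks)
capture_decode_region(dev, scratch, ks)  # no allocation inside capture now
```

Sizing once to the maximum K keeps the grow-on-demand path intact for non-graph callers. The captured region then only ever takes the "already big enough" branch.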
Risk
Lower than #156 Item A: graph replay is numerically identical to direct launch (no new math). The main cost is the Forward refactor; keep TQ/SnapKV state mutation out of the captured region.
Refs: #156 (umbrella), branch perf/cuda-dense-q4k-batched-prefill-156 (Item A, prerequisite context).