perf(cuda): decode CUDA graphs for non-Gemma dense Forward (#156 Item B) #158

@pekkah

Goal

Generalize the Gemma 4 decode CUDA-graph capture/replay (#136) to the non-Gemma per-token CudaForwardPass.Forward, so dense models (Qwen3-8B Q4_K etc.) get the same launch-overhead reduction. Split out of the #156 umbrella as Work Item B.

Decode is HBM-bandwidth-bound (see #142 roofline notes), so the win is not kernel throughput; it is eliminating the ~1k host launches per token. The same overhead exists in Qwen3-8B's decode loop.

Background

ForwardGemma4 (CudaForwardPass.cs:~1082) is structured as: token-varying prefix (embed/PLE) → capturable device region (RunGemma4DeviceRegion, the layer loop + final norm/output) → logits download. The graph helpers it uses are already generic (nothing Gemma-specific in CudaBackend): TryBeginGraphCapture / TryEndGraphCaptureAndInstantiate / LaunchGraphForPosition.

The non-Gemma Forward (:~890) inlines embed + layers + output + download in one method with no capture.

Tasks

  • Refactor Forward into: token-varying prefix (embed) → RunDeviceRegion(position) (layer loop + final norm + output projection) → logits download + sync.
  • Keep the TQ ring-advance (:~1064) and the SnapKV Q-capture (:~964) outside the captured region (they mutate host/ring state that breaks static topology).
  • Wrap RunDeviceRegion with the generic capture helpers; mirror TryRunGemma4DeviceRegionViaGraph (:~1304), reusing the _tqEnabled || _kvEvictedCount > 0 bail.
  • Pre-grow any decode scratch (e.g. the dp4a Q8_1 input buffer) before first capture — capture forbids cudaMalloc. Note: the Q4_K matvec Q8_1 scratch currently grows on demand; ensure it's sized before capture.
  • Verify bit-identical output: graph replay == direct launch on Qwen3-8B (decode N tokens both ways, assert exact logits). A Gemma4CudaGraphParityTests analogue for dense already exists as a template.
  • Measure decode t/s on Qwen3-8B Q4_K (SHARPI_CUDA_GRAPH=0 to bisect); update README and close out in #156 (perf(cuda): generalize Gemma 4 batched prefill + flash attn + decode graphs to dense Q4_K models).
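
The refactor tasks above can be sketched as a host-only model. This is a rough illustration in Python, not the project's C# code: `FakeBackend`, `DenseForward`, `run_device_region`, and `ensure_scratch` are hypothetical stand-ins, no kernels run, and the model only counts host calls. It assumes capture happens once on the first eligible token and that the graph is replayed on every token after that. How the position reaches the replayed graph and how the SnapKV Q-capture is kept outside the region are not modelled.

```python
class FakeBackend:
    """Host-only stand-in for the CUDA backend: counts host calls, runs no kernels."""

    def __init__(self, scratch_bytes=0):
        self.host_calls = 0
        self.capturing = False
        self.scratch_bytes = scratch_bytes

    def launch(self, kernel):
        self.host_calls += 1  # one host call per kernel, captured or not

    def ensure_scratch(self, needed_bytes):
        if needed_bytes <= self.scratch_bytes:
            return  # already large enough, nothing to allocate
        if self.capturing:
            # Models the rule stated in the issue: no allocation while capturing.
            raise RuntimeError("scratch grew during capture")
        self.scratch_bytes = needed_bytes

    def begin_capture(self):
        self.capturing = True

    def end_capture(self):
        self.capturing = False
        return object()  # opaque handle standing in for an instantiated graph

    def launch_graph(self, graph):
        self.host_calls += 1  # a single host call replays the whole region


def run_device_region(backend, num_layers, scratch_needed):
    """Capturable region: layer loop, final norm, output projection."""
    for _ in range(num_layers):
        backend.ensure_scratch(scratch_needed)  # no-op once pre-grown
        backend.launch("attention")
        backend.launch("mlp")
    backend.launch("final_norm")
    backend.launch("output_projection")


class DenseForward:
    def __init__(self, backend, num_layers, scratch_needed):
        self.backend = backend
        self.num_layers = num_layers
        self.scratch_needed = scratch_needed
        self.graph = None
        self.tq_enabled = False
        self.kv_evicted_count = 0

    def forward(self):
        b = self.backend
        b.launch("embed")  # token-varying prefix stays outside the graph
        if self.tq_enabled or self.kv_evicted_count > 0:
            # Bail to direct launches, mirroring the condition named in the tasks.
            run_device_region(b, self.num_layers, self.scratch_needed)
            return
        if self.graph is None:
            b.ensure_scratch(self.scratch_needed)  # pre-grow before capture
            b.begin_capture()
            run_device_region(b, self.num_layers, self.scratch_needed)
            self.graph = b.end_capture()
        b.launch_graph(self.graph)
        # Logits download and sync would follow here, outside the graph.
```

In this model with `num_layers=4`, the first `forward()` issues 12 host calls (embed, 10 recorded launches, one replay) and each later call issues 2. Removing the pre-grow line makes the first capture raise, which is the failure mode the scratch task guards against.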

Risk

Lower than #156 Item A — graph replay is numerically identical to direct launch (no new math). Main cost is the Forward refactor; keep TQ/SnapKV state mutation out of the captured region.
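
Since the claim is exact equality rather than closeness, a parity check is safest when it compares raw float32 bit patterns instead of using a tolerance: ordinary float equality treats 0.0 and -0.0 as equal and NaN as unequal to itself. The sketch below is a hypothetical Python helper, not the project's test code. It assumes logits arrive as one list of float32-representable values per decoded token, and `first_bit_mismatch` is an illustrative name.

```python
import struct


def f32_bits(values):
    """Pack values as little-endian float32 and return the raw bytes."""
    return struct.pack(f"<{len(values)}f", *values)


def first_bit_mismatch(direct_logits, graph_logits):
    """Return (token_index, logit_index) of the first differing float32, or None.

    Both arguments are sequences with one list of logits per decoded token.
    """
    if len(direct_logits) != len(graph_logits):
        raise ValueError("token counts differ")
    for t, (direct, graph) in enumerate(zip(direct_logits, graph_logits)):
        if len(direct) != len(graph):
            raise ValueError(f"vocab sizes differ at token {t}")
        direct_bytes, graph_bytes = f32_bits(direct), f32_bits(graph)
        if direct_bytes == graph_bytes:
            continue  # fast path: the whole token matches bit for bit
        for i in range(len(direct)):
            if direct_bytes[4 * i:4 * i + 4] != graph_bytes[4 * i:4 * i + 4]:
                return (t, i)
    return None
```

A test would decode N tokens with graphs disabled and again with graphs enabled, then assert that `first_bit_mismatch(direct, graph)` is `None`. Reporting the first differing token and logit index makes a failure easier to localize than a bare boolean.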

Refs: #156 (umbrella), branch perf/cuda-dense-q4k-batched-prefill-156 (Item A, prerequisite context).
