Skip to content

perf(gemma4-cuda): CUDA Graph decode + CUDA 12 align + SnapKV post-eviction fix (#136)#139

Merged
pekkah merged 6 commits into
masterfrom
perf/gemma4-cuda-graphs-136
Jun 5, 2026
Merged

perf(gemma4-cuda): CUDA Graph decode + CUDA 12 align + SnapKV post-eviction fix (#136)#139
pekkah merged 6 commits into
masterfrom
perf/gemma4-cuda-graphs-136

Conversation

@pekkah

@pekkah pekkah commented Jun 5, 2026

Copy link
Copy Markdown
Owner

What

CUDA Graph capture/replay for the launch-bound Gemma 4 CUDA decode loop (issue #136), plus
a CUDA-12 runtime alignment and a fix for a pre-existing SnapKV decode bug found along the way.

1. CUDA Graph capture/replay (#136)

Capture the static-topology decode region once on the first decode token and replay it per
token — one cuGraphLaunch + a handful of node-param updates instead of ~1k cuLaunchKernel.
Implemented for both the all-GPU CudaForwardPass (layer loop + final output) and the hybrid
CudaHybridForwardPass (GPU layer loop). Only position varies per token; the five
position-varying ops (RoPE / RoPEWithFactors / KvAppend / Attention / AttentionSwa)
self-register their graph node during capture via cuStreamGetCaptureInfo_v2, and replay
rewrites just their position-derived scalars with cuGraphExecKernelNodeSetParams (v1
CUDA_KERNEL_NODE_PARAMS ABI). Driver-API bindings (nvcuda, toolkit-independent) live in
NvrtcInterop; the shared capture/replay facility lives in CudaBackend.

Off by default behind SHARPI_CUDA_GRAPH (or the UseCudaGraph setter), with a
direct-launch fallback on any capture/replay failure. Gated off under SnapKV eviction and
TurboQuant (they break the captured seqLen == position+1 invariant).

Measured (gemma-4-E4B-it-Q8_0, RTX 4070 Ti): all-GPU decode +5–6% at ~1K ctx
(42→45 t/s); hybrid +1.6%; prefill unchanged. But −9.7% at short context (the ~168
cuGraphExecKernelNodeSetParams/token overhead exceeds the launch-collapse savings when
per-token GPU work is small) — which is why it stays default-off. Decode is largely
memory-bandwidth-bound at realistic ctx after #137's launch-fusions, so the win is single-digit.
Bit-parity proven (graph-on logits bit-identical to direct launches across decode steps, both
paths) in Gemma4CudaGraphParityTests.

2. CUDA 12.x runtime alignment

Pin cudart64_12 / cublas64_12 (was the 11.x SONAMEs) so the runtime/cuBLAS match the NVRTC
the kernels are JIT'd with (nvrtc64_120_0). Driver API + graph calls are toolkit-independent.
No parity change.

3. Fix: SnapKV post-eviction decode on the dense CudaForwardPass

Found while gating graphs against SnapKV: after eviction compacts the KV cache to K entries,
the dense decode kept passing the absolute position to KvAppend (slot) and Attention
(seqLen), so it wrote past the compacted region and attended stale/duplicated slots — silently
undoing the eviction (the hybrid-GDN path already did this right). Fix: index by the physical
slot position - _kvEvictedCount (RoPE keeps the logical position); bit-identical when nothing
evicted. Also disable SnapKV for Gemma-4-style models — their SWA ring caches + per-layer
head_dim make ApplySnapKvEviction incoherent regardless of indexing. New oracle asserts
KvLength stays full even with an explicit over-budget budget. The old finite/≥2-distinct test
was too weak to catch the bug.

Test

  • Gemma4CudaGraphParityTests (3): all-GPU, hybrid, SnapKV-configured-no-evict — bit-identical.
  • Gemma4CudaForwardPassTests incl. new SnapKvDisabled_NoEvictionEvenOverBudget.
  • CudaForwardPassSnapKvTests (Qwen3-8B): post-eviction path still coherent.
  • Full non-ForwardPass suite (Core/Server/TurboQuant/Pipeline) green; clean Release build (0 warnings, AOT-safe).

Follow-ups (deferred)

  • Device-side position (1 memcpy + 1 launch vs ~168 node-param updates) to remove the
    short-context regression and earn a default-on flip.

🤖 Generated with Claude Code

pekkah and others added 4 commits June 5, 2026 14:09
cuda: pin runtime to CUDA 12.x to match NVRTC (cudart64_12 / cublas64_12)

The backend JIT-compiles its kernels with nvrtc64_120_0 (CUDA 12.x) but bound
cuBLAS/cudart to the 11.x SONAMEs (cublas64_11 / cudart64_110). Align both on
12.x so the runtime, cuBLAS, and NVRTC come from one toolkit generation. The
driver API (nvcuda) and the CUDA-graph calls are toolkit-independent. On a host
with several toolkits installed the loader picks the first 12.x bin on PATH.

No parity change: 15 Gemma 4 CUDA tests + 312 non-ForwardPass tests stay green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@
Gemma 4 decode launches ~1k tiny kernels/token. Capture the decode region once
on the first decode token and replay it per token, so the host issues one
cuGraphLaunch plus a handful of node-param updates instead of ~1k cuLaunchKernel.

The decode topology is static across tokens; only position varies. The region is
pure on-device compute (transformer layers + final norm/output/softcap for the
all-GPU path; the GPU layer loop for the hybrid path) -- embedding, PLE, uploads,
and the logits download stay outside, since host transfers cannot be captured. The
five position-varying ops (RoPE / RoPEWithFactors / KvAppend / Attention /
AttentionSwa) self-register their graph node during capture via
cuStreamGetCaptureInfo_v2; replay rewrites just their position-derived scalars with
cuGraphExecKernelNodeSetParams (v1 CUDA_KERNEL_NODE_PARAMS ABI).

Driver-API graph bindings live in NvrtcInterop (nvcuda, toolkit-independent); the
shared capture/replay facility (TryBeginGraphCapture / TryEndGraphCaptureAndInstantiate
/ LaunchGraphForPosition) lives in CudaBackend. Off by default behind SHARPI_CUDA_GRAPH
(or the UseCudaGraph setter); any capture/replay failure disables graphs for the
session and falls back to direct launches. Gated off when SnapKV or TurboQuant is
active (they break the captured seqLen == position+1 invariant) -- SnapKV
coexistence is a follow-up.

Bit-parity: Gemma4CudaGraphParityTests asserts graph-on logits are bit-identical to
direct launches across decode steps, for both the all-GPU and hybrid paths.

Measured (gemma-4-E4B-it-Q8_0, RTX 4070 Ti, ~1K ctx, warm): all-GPU decode
42.4 -> 45.0 t/s (+6.1%) when eligible (SnapKV off); hybrid 6.40 -> 6.50 t/s
(+1.6%); prefill unchanged. Decode is largely memory-bandwidth-bound at realistic
context after the #137 launch-fusions, so graphs reclaim only the launch time not
already hidden behind GPU work.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…unevicted (#136)

The first graph cut disabled CUDA-graph decode whenever a SnapKV budget was
*configured* (_snapKvEffectiveBudget > 0). But SnapKV only breaks the graph's
seqLen == position+1 invariant when it actually *evicts* and compacts the KV cache
(K < N); a configured budget that the prompt never exceeds leaves the cache filling
sequentially, so graphs stay valid there.

Gate on a new _snapKvEvicted flag instead: set in ApplySnapKvEviction once the cache
is compacted, cleared on a full ResetCache. This unlocks the default config, where
SnapKV auto-enables (budget 1024) but typical prompts fit under it.

New oracle Gemma4_E4B_CudaGraph_AllGpu_SnapKvConfiguredNoEvict_BitMatches: with an
explicit SHARPI_SNAPKV_BUDGET and an under-budget prompt, graphs must engage
(GraphReady) and stay bit-identical to direct launches across decode steps.

Measured (gemma-4-E4B-it-Q8_0, RTX 4070 Ti, default config = auto-SnapKV on,
972-token prompt, no eviction): all-GPU decode 42.7 -> 44.9 t/s (+5.2%), in line with
the +6.1% SnapKV-off number. Post-eviction (prompt > budget) still falls back to
direct launches.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rdPass

After SnapKV evicts, ApplySnapKvEviction compacts each layer's KV cache to K
entries at slots [0, K) and sets _kvLength = K. But the dense decode passed the
absolute `position` straight to KvAppend (slot) and Attention (seqLen = position+1),
so once the cache was compacted (cache-fill K != absolute position) it wrote new
tokens past the compacted region and attended across the stale/duplicated slots
between K and the logical position -- silently undoing the eviction and, near the
context end, trending toward an out-of-bounds slot. The hybrid-GDN path already does
this right (CudaHybridGdnForwardPass.GpuAttnBlock passes kvPosition = _kvCache.Length
as the physical slot while RoPE keeps the logical position); the dense path didn't.

The existing CudaForwardPassSnapKv test only checked finite + >=2 distinct logits,
which the bug passes because attending to *more* (un-evicted) context still looks
coherent -- so it slipped through.

Fix, mirroring the hybrid path:

- Track _kvEvictedCount (= N - K at eviction, 0 otherwise). Decode indexes the cache
  by the physical slot `position - _kvEvictedCount` for both KvAppend and Attention
  seqLen; RoPE keeps the logical `position`. When nothing was evicted the value is 0,
  so behavior is bit-identical to before. Applied to Forward and ForwardProfiled.
  (This subsumes the earlier _snapKvEvicted bool used by the CUDA-graph guard:
  graphs now gate on _kvEvictedCount > 0.)

- Disable SnapKV for Gemma-4-style models. Their SWA layers use sliding-window ring
  caches and layers carry per-layer head_dim, so the full-context scoring and
  uniform-kvDim compaction in ApplySnapKvEviction can't be applied coherently to them
  at all (out-of-range gather, wrong row stride) -- the position remap isn't enough.
  Force the budget to 0 for the Gemma-4-like path (warn if explicitly requested) and
  use the full KV cache. New oracle Gemma4_E4B_CudaForward_SnapKvDisabled_* asserts
  KvLength stays at the full prompt length even with an explicit over-budget budget.

Affects the all-GPU CUDA path (-g -1) for any model whenever SnapKV evicts (prompt >
budget; default auto-budget 1024). The CUDA-graph decode (#136) was already safe --
its guard only ran graphs before any eviction.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces CUDA Graphs support (issue #136) for the Gemma 4 decode loop across both the all-GPU and hybrid forward passes, significantly reducing host launch overhead by capturing the static topology and updating only position-derived kernel parameters. It also updates CUDA runtime and cuBLAS bindings to target CUDA 12.x, disables SnapKV for Gemma-4-style models to prevent cache corruption, and adds comprehensive bit-parity tests to ensure graph replay matches direct launches. The feedback suggests replacing the magic number 16 in TrackPositionNode with a shared named constant to align with the stackalloc buffer size in LaunchGraphForPosition.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread src/SharpInference.Cuda/CudaBackend.cs Outdated
int rc = NvrtcInterop.StreamGetCaptureInfo(
_stream, out int status, out _, out _, out nint deps, out nuint numDeps);
if (rc != 0 || status != NvrtcInterop.CU_STREAM_CAPTURE_STATUS_ACTIVE
|| numDeps != 1 || deps == nint.Zero || argValues.Length > 16)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better maintainability and to prevent potential stack overflows, the magic number 16 should be replaced with a named constant. This constant should be shared with LaunchGraphForPosition where the stackalloc buffer size is defined.

Please consider defining a class-level constant like private const int MaxGraphKernelArgs = 16; and using it here and in LaunchGraphForPosition.

                    || numDeps != 1 || deps == nint.Zero || argValues.Length > MaxGraphKernelArgs)

pekkah and others added 2 commits June 5, 2026 16:16
review(gemma4-cuda): address pr-review-toolkit findings (#139)

- AbortGraphCapture now DiscardGraph()s unconditionally: a LaunchGraphForPosition
  failure reaches it after instantiate already built _graphExec/_capturedGraph
  (capture already ended), so the prior guard leaked the exec graph and left
  GraphReady stuck true. (silent-failure-hunter, MEDIUM)
- Fix stale SONAME comment in NvrtcInterop (cudart64_12/cublas64_12, not _110/_11).
- Document that TruncateTo intentionally preserves _kvEvictedCount (decode-token
  rewind keeps the physical mapping valid).
- Document the capture-region no-cudaMalloc requirement (Q8_1 buffer pre-sizing).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@
review(gemma4-cuda): share GraphMaxKernelArgs constant (gemini, #139)

Replace the magic 16 in TrackPositionNode + the local MaxArgs in
LaunchGraphForPosition with a single GraphMaxKernelArgs const so the snapshot
bound and the stackalloc cell/ptr buffer size cannot drift apart.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@
@pekkah

pekkah commented Jun 5, 2026

Copy link
Copy Markdown
Owner Author

/gemini review

@pekkah pekkah merged commit d84cc62 into master Jun 5, 2026
1 check passed
@pekkah pekkah deleted the perf/gemma4-cuda-graphs-136 branch June 5, 2026 13:21

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request upgrades the CUDA runtime and cuBLAS bindings to CUDA 12.x and introduces CUDA Graphs support (issue #136) to optimize the Gemma 4 decode loop by capturing and replaying the decode region. It also force-disables SnapKV for Gemma-4-style models to prevent cache corruption and adds corresponding parity and regression tests. The review feedback highlights a potential resource leak in both CudaForwardPass and CudaHybridForwardPass where an exception during graph execution does not immediately release captured graph resources; calling AbortGraphCapture in the catch blocks is recommended to ensure immediate cleanup.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

if (_graphCaptured && _gpu.GraphReady)
{
try { _gpu.LaunchGraphForPosition(position); return true; }
catch { _useCudaGraph = false; _graphCaptured = false; return false; }

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In case of an exception during LaunchGraphForPosition, the captured CUDA graph resources are not released until Dispose is called. This could lead to a resource leak for the duration of the session. It would be better to call _gpu.AbortGraphCapture() here to release the resources immediately, similar to how it's handled in the catch block for the initial graph capture.

            catch { _gpu.AbortGraphCapture(); _useCudaGraph = false; _graphCaptured = false; return false; }

if (_graphCaptured && _gpu.GraphReady)
{
try { _gpu.LaunchGraphForPosition(position); return true; }
catch { _useCudaGraph = false; _graphCaptured = false; return false; }

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the all-GPU forward pass, an exception during LaunchGraphForPosition could lead to a resource leak for the session's duration because the CUDA graph resources are not released. Calling _gpu.AbortGraphCapture() in the catch block would ensure immediate cleanup.

            catch { _gpu.AbortGraphCapture(); _useCudaGraph = false; _graphCaptured = false; return false; }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant