feat(spec): GPU draft sampled speculative decoding for Gemma 4 (#178)#295

Merged
pekkah merged 2 commits into master from feat/178-gemma4-gpu-draft-sampled-spec
Jun 18, 2026

@pekkah pekkah commented Jun 18, 2026

Closes #178.

The issue's "current limitations" table was stale — the external GPU-draft + N>1 batched-verify infrastructure already existed for dense models (#207). Only two real gaps remained for Gemma 4 12B at temp 0.2.

Gap A — Gemma 4 was excluded from BatchVerify

CudaForwardPass.SupportsBatchVerify gated on DenseBatchedDecodeSupported() (false for _isGemma4Like), even though the compute path (BatchVerify → BatchForwardMulti → RunBatchedTrunk → RunBatchedTrunkGemma4) already handles Gemma 4. The fix widens the gate to also admit Gemma4BatchedDecodeSupported().

The real work was proving the one-owned-cache-bound-to-k-rows packed verify is correct for Gemma 4's SWA rings / shared-KV / k_eq_v / PLE / softcap: the per-sequence attention loop appends-then-attends in ascending row order, k ≪ window so append slots (startPos+n) % ringSize are distinct, and appendCtx/effLayerCtx equal the owned cache's ring allocation (so a wrapped ring indexes identically to single-token decode). SupportsPartialRewind => true unconditionally, so the GPU Gemma draft ctor doesn't throw.
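The slot-distinctness part of that argument can be sketched in isolation. This is a minimal illustration, not code from the PR: `RingSlotSketch`, `AppendSlots`, and `AllDistinct` are hypothetical names, and the only thing taken from the description is the slot formula `(startPos + n) % ringSize`.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper for illustration; not part of CudaForwardPass.
public static class RingSlotSketch
{
    // Ring slots written when k rows are appended starting at startPos.
    public static int[] AppendSlots(int startPos, int k, int ringSize)
    {
        var slots = new int[k];
        for (int n = 0; n < k; n++)
            slots[n] = (startPos + n) % ringSize;
        return slots;
    }

    // True when no two appended rows land on the same slot.
    public static bool AllDistinct(int[] slots) =>
        new HashSet<int>(slots).Count == slots.Length;
}
```

With `startPos = 1022`, `k = 4`, and `ringSize = 1024`, the slots are 1022, 1023, 0, 1: the ring wraps but no slot repeats. Slots only collide once k exceeds ringSize, which the k ≪ window condition rules out.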

Gap B — temp>0 (sampled) accept

SpeculativeDecoder was greedy-only. This PR adds distribution-preserving speculative sampling (Leviathan/Chen) as the default at temp > 0: each draft token is sampled with proposal probability q and accepted with probability min(1, p/q), where p is the target's probability for that token; on rejection the correction is resampled from the normalized residual max(0, p−q); on a full accept the bonus token is drawn from the last verify position. The correction/bonus is deferred to the next step as its certain token, so each step still costs one batched target pass, like greedy. --spec-draft-p-min opts into a looser, distribution-diverging accept. The greedy (temp ≤ 0) path is unchanged and byte-stable.
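The accept test and the residual resample can be sketched as follows. This is a minimal sketch of the standard speculative-sampling rule under stated assumptions, not the PR's SpeculativeDecoder code: `SpecSamplingSketch` is a hypothetical class, the handling of q = 0 and the `-1` return for an empty residual are assumptions made for the example, and `p` and `q` are assumed to be already-filtered probability vectors over the same vocabulary.

```csharp
using System;

// Hypothetical sketch; names and edge-case handling are assumptions.
public static class SpecSamplingSketch
{
    // Accept a draft token with probability min(1, p/q), where pTok and qTok
    // are the target and draft probabilities of that token.
    public static bool Accept(float pTok, float qTok, Random rng)
    {
        // Assumption: q == 0 should not occur for a token the draft sampled;
        // if it does, p/q is unbounded, so accept whenever p > 0.
        if (qTok <= 0f) return pTok > 0f;
        return rng.NextDouble() < Math.Min(1.0, (double)pTok / qTok);
    }

    // On rejection, draw the correction from max(0, p - q), normalized.
    // Returns -1 when the residual is empty (p <= q everywhere), in which
    // case a caller can fall back to sampling from p directly.
    public static int ResampleResidual(ReadOnlySpan<float> p, ReadOnlySpan<float> q, Random rng)
    {
        float sum = 0f;
        for (int i = 0; i < p.Length; i++)
            sum += Math.Max(0f, p[i] - q[i]);
        if (sum <= 0f) return -1;

        float target = (float)rng.NextDouble() * sum;
        float cum = 0f;
        int lastPositive = -1;
        for (int i = 0; i < p.Length; i++)
        {
            float r = p[i] - q[i];
            if (r <= 0f) continue;
            lastPositive = i;
            cum += r;
            if (target <= cum) return i;
        }
        return lastPositive; // rounding fallback: last token with positive residual
    }
}
```

The residual only has mass on tokens the draft under-proposed relative to the target, which is what lets accepted and corrected tokens together follow the target distribution regardless of q.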

  • New Sampler helpers: BuildFilteredDistribution / SampleWithDistribution / ResampleResidual.
  • CLI relaxes the temp>0 fallback (model-draft only — prompt-lookup has no q), threads sampling params + the seeded RNG, adds SHARPI_SPEC_SAMPLE=0 as a bisect switch (→ plain sampled decode), and rejects repeat-penalty / logit-bias with draft+temp>0 (v1: draft and target must agree on the distribution).

Tests

  • Distribution-preservation oracle (mock): emitted histogram converges to softmax(target) regardless of the draft's q — the core invariant. Plus determinism and greedy-unchanged.
  • Sampler unit tests for the new helpers.
  • Gemma 4 CUDA BatchVerify parity: k=4/6, q8_0 KV, SWA ring-wrap, truncate/commit rollback, e2e greedy. Gemma 4 batched decode is argmax-stable (not bit-exact — cuBLAS fp16 GEMM), so the oracles assert argmax-or-near-tie + top-5 overlap with maxAbs as a coarse fp16-scale guard.
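The shape of such a parity assertion can be sketched as below. This is an illustration only: `ParitySketch` and its methods are hypothetical, and the PR description does not state the near-tie tolerance or the maxAbs bound, so any thresholds passed in are assumptions.

```csharp
using System;
using System.Linq;

// Hypothetical parity helpers; tolerances are chosen by the caller.
public static class ParitySketch
{
    public static int ArgMax(float[] x)
    {
        int best = 0;
        for (int i = 1; i < x.Length; i++)
            if (x[i] > x[best]) best = i;
        return best;
    }

    // Passes when both argmaxes agree, or when the batched winner is within
    // tieEps of the reference winner in the reference logits (a near tie).
    public static bool ArgmaxOrNearTie(float[] reference, float[] batched, float tieEps)
    {
        int a = ArgMax(reference), b = ArgMax(batched);
        return a == b || reference[a] - reference[b] <= tieEps;
    }

    static int[] TopK(float[] x, int k) =>
        Enumerable.Range(0, x.Length).OrderByDescending(i => x[i]).Take(k).ToArray();

    // Number of token ids shared by the two top-k sets.
    public static int TopKOverlap(float[] reference, float[] batched, int k) =>
        TopK(reference, k).Intersect(TopK(batched, k)).Count();

    // Largest absolute element-wise difference, used as a coarse guard.
    public static float MaxAbs(float[] a, float[] b)
    {
        float m = 0f;
        for (int i = 0; i < a.Length; i++)
            m = Math.Max(m, Math.Abs(a[i] - b[i]));
        return m;
    }
}
```

Checking the near tie against the reference logits rather than the batched ones keeps the test anchored to the single-token decode path that is treated as ground truth.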

Validation (4070 Ti, 12 GB)

  • 99 tests pass, including the ring-wrap oracle running against real E4B Q8_0 (the high-risk case).
  • End-to-end CLI (E4B Q8_0 target + E4B q4_0 draft, temp 0.2, q8_0 KV, -g -1): 73% acceptance, coherent output. SHARPI_SPEC_SAMPLE=0 falls back cleanly.
  • Release build clean (warnings-as-errors + AOT analyzers).

Throughput here was modest only because E4B-as-draft isn't small relative to the E4B target; the issue's headline ~1.3× needs a 12B target + a genuinely tiny draft. The mechanism is proven.

Kill-switches

  • SHARPI_SPEC_SAMPLE=0 — disable sampled spec at temp>0 (→ plain sampled decode), for bisection.
  • SHARPI_SPEC_BATCH_VERIFY=0 — existing sequential-verify fallback.

🤖 Generated with Claude Code

The external GPU-draft + batched-verify infra already existed for dense
models (#207); two gaps remained for Gemma 4 12B at temp 0.2:

Gap A — Gemma 4 was excluded from BatchVerify. SupportsBatchVerify gated
on DenseBatchedDecodeSupported() (false for Gemma 4), even though the
compute path (BatchForwardMulti -> RunBatchedTrunkGemma4) already handles
Gemma 4. Widen the gate to also admit Gemma4BatchedDecodeSupported(). The
one-owned-cache-bound-to-k-rows packed verify is correct for Gemma 4's
SWA rings / shared-KV / k_eq_v / PLE / softcap: the per-sequence attn loop
appends-then-attends in ascending row order, k << window so the append
slots are distinct, and appendCtx/effLayerCtx equal the owned cache's ring
alloc (so wraps index identically to single-token decode).

Gap B — temp>0 sampled accept. SpeculativeDecoder was greedy-only. Add
distribution-preserving speculative sampling (Leviathan/Chen) as the
default: draft tokens sampled with proposal prob q, accepted with
min(1, p/q), rejection resampled from the residual max(0, p-q), full
accept drawn from the last verify position; correction/bonus deferred to
the next step's certain token (one batched target pass per step, like
greedy). --spec-draft-p-min opts into a looser, distribution-diverging
accept. Greedy (temp<=0) path is unchanged and byte-stable.

New Sampler helpers: BuildFilteredDistribution / SampleWithDistribution /
ResampleResidual. CLI relaxes the temp>0 fallback (model-draft only),
threads sampling params + the seeded RNG, adds SHARPI_SPEC_SAMPLE=0 as a
bisect switch, and rejects repeat-penalty / logit-bias with draft+temp>0
(v1: draft and target must agree on the distribution).

Tests: mock distribution-preservation oracle (emitted histogram converges
to softmax(target) regardless of draft q) + determinism + greedy-unchanged;
Sampler unit tests for the new helpers; Gemma 4 CUDA BatchVerify parity
(k=4/6, q8_0 KV, SWA ring-wrap, truncate/commit, e2e). Gemma 4 batched
decode is argmax-stable not bit-exact (cuBLAS fp16 GEMM), so the parity
oracles assert argmax-or-near-tie + top-5 overlap with maxAbs as a coarse
fp16-scale guard.

Verified on a 4070 Ti: real E4B Q8_0 BatchVerify incl. the ring-wrap
oracle; e2e temp-0.2 sampled spec (q8_0 KV, -g -1) at 73% acceptance.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements sampled (distribution-preserving) speculative decoding for temperatures greater than zero, extending speculative decoding support to Gemma-4 configurations on CUDA. It introduces distribution-building and residual-resampling helpers in Sampler, integrates them into SpeculativeDecoder via a new StepSampled path, and adds comprehensive test coverage. The review feedback suggests several performance optimizations: avoiding a third array pass in ResampleResidual by tracking the last positive residual index, skipping redundant distribution building in SampleWithDistribution when temperature is zero or less, and avoiding a redundant distribution rebuild in SpeculativeDecoder.StepSampled by sampling directly from the already populated pDist buffer.

Comment on lines +474 to +489
        float target = (float)rng.NextDouble() * sum;
        float cum = 0f;
        for (int i = 0; i < p.Length; i++)
        {
            float r = p[i] - q[i];
            if (r > 0f)
            {
                cum += r;
                if (target <= cum) return i;
            }
        }
        // Rounding fallback: last token with positive residual.
        for (int i = p.Length - 1; i >= 0; i--)
            if (p[i] - q[i] > 0f) return i;
        return SampleFromDistribution(p, rng);
    }

medium

In ResampleResidual, we can optimize the loop to avoid the third pass over the arrays. By tracking the last index with a positive residual during the second pass, we can return it directly if the cumulative sum check is not met due to floating-point inaccuracies. This reduces the worst-case passes from 3 to 2.

        float target = (float)rng.NextDouble() * sum;
        float cum = 0f;
        int lastPositiveIdx = -1;
        for (int i = 0; i < p.Length; i++)
        {
            float r = p[i] - q[i];
            if (r > 0f)
            {
                lastPositiveIdx = i;
                cum += r;
                if (target <= cum) return i;
            }
        }
        if (lastPositiveIdx >= 0) return lastPositiveIdx;
        return SampleFromDistribution(p, rng);

Comment on lines +172 to +179
    public static int SampleWithDistribution(ReadOnlySpan<float> logits, SamplingParams p, Span<float> probs, Random? rng = null)
    {
        BuildFilteredDistribution(logits, p, probs);
        if (p.Temperature <= 0f)
            return Greedy(logits);
        rng ??= Random.Shared;
        return SampleFromDistribution(probs, rng);
    }

medium

In SampleWithDistribution, when p.Temperature <= 0f, we can avoid calling BuildFilteredDistribution entirely. This avoids a redundant call to Greedy(logits) and the overhead of copying/clearing arrays. We also introduce a public SampleFromProbs helper to allow sampling from a pre-computed distribution directly.

    public static int SampleWithDistribution(ReadOnlySpan<float> logits, SamplingParams p, Span<float> probs, Random? rng = null)
    {
        if (p.Temperature <= 0f)
        {
            int argmax = Greedy(logits);
            probs.Clear();
            probs[argmax] = 1f;
            return argmax;
        }
        BuildFilteredDistribution(logits, p, probs);
        rng ??= Random.Shared;
        return SampleFromDistribution(probs, rng);
    }

    /// <summary>
    /// Sample a token from a pre-computed probability distribution.
    /// </summary>
    public static int SampleFromProbs(ReadOnlySpan<float> probs, Random rng)
    {
        return SampleFromDistribution(probs, rng);
    }

Comment on lines +411 to +415
            // Reject at i: correction from the residual at this position (true sampling) or a
            // fresh target sample (looser pMin mode — already off-distribution).
            correction = _trueSpecSampling
                ? Sampler.ResampleResidual(pDist, draftDists[i], rng)
                : Sampler.SampleWithDistribution(batch[i - 1], sampling, pDist, rng);

medium

Since pDist has already been populated with the filtered distribution of batch[i - 1] at line 390, calling Sampler.SampleWithDistribution here is redundant and will rebuild the entire distribution from scratch. We can optimize this by sampling directly from pDist using the new Sampler.SampleFromProbs helper.

            // Reject at i: correction from the residual at this position (true sampling) or a
            // fresh target sample (looser pMin mode — already off-distribution).
            correction = _trueSpecSampling
                ? Sampler.ResampleResidual(pDist, draftDists[i], rng)
                : Sampler.SampleFromProbs(pDist, rng);

Three behavior-preserving cleanups from the PR #295 review (determinism and
distribution identical; all 44 mock oracles still pass):

- ResampleResidual: track the last positive-residual index during the
  cumulative pass so the rounding fallback no longer needs a third pass.
- SampleWithDistribution: short-circuit temp<=0 to a one-hot (no filter
  pipeline, no RNG draw); add a public SampleFromProbs for sampling a
  pre-built distribution.
- StepSampled looser-pMin reject: pDist already holds
  BuildFilteredDistribution(batch[i-1]) from the accept test, so sample it
  directly via SampleFromProbs instead of rebuilding it.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@pekkah pekkah merged commit a16b179 into master Jun 18, 2026
1 check passed
@pekkah pekkah deleted the feat/178-gemma4-gpu-draft-sampled-spec branch June 18, 2026 10:41

Development

Successfully merging this pull request may close these issues.

GPU draft-MTP speculative decoding for Gemma 4 12B (decode 54 → ~70 t/s)
