feat(spec): GPU draft sampled speculative decoding for Gemma 4 (#178) #295
The external GPU-draft + batched-verify infra already existed for dense models (#207); two gaps remained for Gemma 4 12B at temp 0.2:

Gap A: Gemma 4 was excluded from BatchVerify. SupportsBatchVerify gated on DenseBatchedDecodeSupported() (false for Gemma 4), even though the compute path (BatchForwardMulti -> RunBatchedTrunkGemma4) already handles Gemma 4. Widen the gate to also admit Gemma4BatchedDecodeSupported(). The one-owned-cache-bound-to-k-rows packed verify is correct for Gemma 4's SWA rings / shared-KV / k_eq_v / PLE / softcap: the per-sequence attn loop appends-then-attends in ascending row order, k << window so the append slots are distinct, and appendCtx/effLayerCtx equal the owned cache's ring alloc (so wraps index identically to single-token decode).

Gap B: temp>0 sampled accept. SpeculativeDecoder was greedy-only. Add distribution-preserving speculative sampling (Leviathan/Chen) as the default: draft tokens sampled with proposal prob q, accepted with min(1, p/q), rejection resampled from the residual max(0, p-q), full accept drawn from the last verify position; correction/bonus deferred to the next step's certain token (one batched target pass per step, like greedy). --spec-draft-p-min opts into a looser, distribution-diverging accept. Greedy (temp<=0) path is unchanged and byte-stable.

New Sampler helpers: BuildFilteredDistribution / SampleWithDistribution / ResampleResidual.

CLI relaxes the temp>0 fallback (model-draft only), threads sampling params + the seeded RNG, adds SHARPI_SPEC_SAMPLE=0 as a bisect switch, and rejects repeat-penalty / logit-bias with draft+temp>0 (v1: draft and target must agree on the distribution).

Tests: mock distribution-preservation oracle (emitted histogram converges to softmax(target) regardless of draft q) + determinism + greedy-unchanged; Sampler unit tests for the new helpers; Gemma 4 CUDA BatchVerify parity (k=4/6, q8_0 KV, SWA ring-wrap, truncate/commit, e2e). Gemma 4 batched decode is argmax-stable, not bit-exact (cuBLAS fp16 GEMM), so the parity oracles assert argmax-or-near-tie + top-5 overlap, with maxAbs as a coarse fp16-scale guard.

Verified on a 4070 Ti: real E4B Q8_0 BatchVerify incl. the ring-wrap oracle; e2e temp-0.2 sampled spec (q8_0 KV, -g -1) at 73% acceptance.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
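The accept rule in Gap B can be sketched in a few lines. This is a minimal, standalone illustration of the Leviathan/Chen accept/resample step for a single position, not the PR's code: SpecSamplingSketch, Accept, SampleFrom and Step are hypothetical names, and the ResampleResidual here is a local stand-in rather than the Sampler helper. It assumes p and q are already-normalized probability vectors of equal length.

```csharp
using System;

// Illustrative sketch only: all names here are hypothetical, not the PR's Sampler API.
public static class SpecSamplingSketch
{
    // Sample an index from a normalized probability vector.
    public static int SampleFrom(ReadOnlySpan<float> probs, Random rng)
    {
        double target = rng.NextDouble();
        double cum = 0.0;
        int lastPositive = 0;
        for (int i = 0; i < probs.Length; i++)
        {
            if (probs[i] <= 0f) continue;
            lastPositive = i;
            cum += probs[i];
            if (target < cum) return i;
        }
        return lastPositive; // rounding fallback
    }

    // Accept the drafted token with probability min(1, p[token] / q[token]).
    public static bool Accept(ReadOnlySpan<float> p, ReadOnlySpan<float> q, int token, Random rng)
    {
        float qt = q[token];
        if (qt <= 0f) return false; // a token with q == 0 cannot have been drafted; reject defensively
        float ratio = p[token] / qt;
        return ratio >= 1f || rng.NextDouble() < ratio;
    }

    // On rejection, sample from the normalized residual max(0, p - q).
    public static int ResampleResidual(ReadOnlySpan<float> p, ReadOnlySpan<float> q, Random rng)
    {
        double sum = 0.0;
        for (int i = 0; i < p.Length; i++) sum += Math.Max(0f, p[i] - q[i]);
        if (sum <= 0.0) return SampleFrom(p, rng); // p == q everywhere: any draw from p is fine

        double target = rng.NextDouble() * sum;
        double cum = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < p.Length; i++)
        {
            float r = p[i] - q[i];
            if (r <= 0f) continue;
            lastPositive = i;
            cum += r;
            if (target <= cum) return i;
        }
        return lastPositive; // rounding fallback: last token with positive residual
    }

    // One draft-then-verify step for a single position: the emitted token is distributed as p.
    public static int Step(ReadOnlySpan<float> p, ReadOnlySpan<float> q, Random rng)
    {
        int drafted = SampleFrom(q, rng);
        return Accept(p, q, drafted, rng) ? drafted : ResampleResidual(p, q, rng);
    }
}
```

Calling Step repeatedly with a fixed p and a deliberately mismatched q should produce a histogram that approaches p, which is the same invariant the mock distribution-preservation oracle checks.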
Code Review
This pull request implements sampled (distribution-preserving) speculative decoding for temperatures greater than zero, extending speculative decoding support to Gemma-4 configurations on CUDA. It introduces distribution-building and residual-resampling helpers in Sampler, integrates them into SpeculativeDecoder via a new StepSampled path, and adds comprehensive test coverage. The review feedback suggests several performance optimizations: avoiding a third array pass in ResampleResidual by tracking the last positive residual index, skipping redundant distribution building in SampleWithDistribution when temperature is zero or less, and avoiding a redundant distribution rebuild in SpeculativeDecoder.StepSampled by sampling directly from the already populated pDist buffer.
    float target = (float)rng.NextDouble() * sum;
    float cum = 0f;
    for (int i = 0; i < p.Length; i++)
    {
        float r = p[i] - q[i];
        if (r > 0f)
        {
            cum += r;
            if (target <= cum) return i;
        }
    }
    // Rounding fallback: last token with positive residual.
    for (int i = p.Length - 1; i >= 0; i--)
        if (p[i] - q[i] > 0f) return i;
    return SampleFromDistribution(p, rng);
}
In ResampleResidual, we can optimize the loop to avoid the third pass over the arrays. By tracking the last index with a positive residual during the second pass, we can return it directly if the cumulative sum check is not met due to floating-point inaccuracies. This reduces the worst-case passes from 3 to 2.
float target = (float)rng.NextDouble() * sum;
float cum = 0f;
int lastPositiveIdx = -1;
for (int i = 0; i < p.Length; i++)
{
    float r = p[i] - q[i];
    if (r > 0f)
    {
        lastPositiveIdx = i;
        cum += r;
        if (target <= cum) return i;
    }
}
if (lastPositiveIdx >= 0) return lastPositiveIdx;
return SampleFromDistribution(p, rng);

public static int SampleWithDistribution(ReadOnlySpan<float> logits, SamplingParams p, Span<float> probs, Random? rng = null)
{
    BuildFilteredDistribution(logits, p, probs);
    if (p.Temperature <= 0f)
        return Greedy(logits);
    rng ??= Random.Shared;
    return SampleFromDistribution(probs, rng);
}
In SampleWithDistribution, when p.Temperature <= 0f, we can avoid calling BuildFilteredDistribution entirely. This avoids a redundant call to Greedy(logits) and the overhead of copying/clearing arrays. We also introduce a public SampleFromProbs helper to allow sampling from a pre-computed distribution directly.
public static int SampleWithDistribution(ReadOnlySpan<float> logits, SamplingParams p, Span<float> probs, Random? rng = null)
{
    if (p.Temperature <= 0f)
    {
        int argmax = Greedy(logits);
        probs.Clear();
        probs[argmax] = 1f;
        return argmax;
    }
    BuildFilteredDistribution(logits, p, probs);
    rng ??= Random.Shared;
    return SampleFromDistribution(probs, rng);
}

/// <summary>
/// Sample a token from a pre-computed probability distribution.
/// </summary>
public static int SampleFromProbs(ReadOnlySpan<float> probs, Random rng)
{
    return SampleFromDistribution(probs, rng);
}

// Reject at i: correction from the residual at this position (true sampling) or a
// fresh target sample (looser pMin mode, already off-distribution).
correction = _trueSpecSampling
    ? Sampler.ResampleResidual(pDist, draftDists[i], rng)
    : Sampler.SampleWithDistribution(batch[i - 1], sampling, pDist, rng);
Since pDist has already been populated with the filtered distribution of batch[i - 1] at line 390, calling Sampler.SampleWithDistribution here is redundant and will rebuild the entire distribution from scratch. We can optimize this by sampling directly from pDist using the new Sampler.SampleFromProbs helper.
// Reject at i: correction from the residual at this position (true sampling) or a
// fresh target sample (looser pMin mode, already off-distribution).
correction = _trueSpecSampling
    ? Sampler.ResampleResidual(pDist, draftDists[i], rng)
    : Sampler.SampleFromProbs(pDist, rng);

Three behavior-preserving cleanups from the PR #295 review (determinism and distribution identical; all 44 mock oracles still pass):

- ResampleResidual: track the last positive-residual index during the cumulative pass so the rounding fallback no longer needs a third pass.
- SampleWithDistribution: short-circuit temp<=0 to a one-hot (no filter pipeline, no RNG draw); add a public SampleFromProbs for sampling a pre-built distribution.
- StepSampled looser-pMin reject: pDist already holds BuildFilteredDistribution(batch[i-1]) from the accept test, so sample it directly via SampleFromProbs instead of rebuilding it.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Closes #178.
The issue's "current limitations" table was stale — the external GPU-draft + N>1 batched-verify infrastructure already existed for dense models (#207). Only two real gaps remained for Gemma 4 12B at temp 0.2.
Gap A: Gemma 4 was excluded from BatchVerify
CudaForwardPass.SupportsBatchVerify gated on DenseBatchedDecodeSupported() (false for _isGemma4Like), even though the compute path (BatchVerify → BatchForwardMulti → RunBatchedTrunk → RunBatchedTrunkGemma4) already handles Gemma 4. The fix widens the gate to also admit Gemma4BatchedDecodeSupported().
The real work was proving the one-owned-cache-bound-to-k-rows packed verify is correct for Gemma 4's SWA rings / shared-KV / k_eq_v / PLE / softcap: the per-sequence attention loop appends-then-attends in ascending row order, k ≪ window so append slots (startPos+n) % ringSize are distinct, and appendCtx/effLayerCtx equal the owned cache's ring allocation (so a wrapped ring indexes identically to single-token decode).
SupportsPartialRewind => true unconditionally, so the GPU Gemma draft ctor doesn't throw.
Gap B: temp>0 (sampled) accept
SpeculativeDecoder was greedy-only. Adds distribution-preserving speculative sampling (Leviathan/Chen) as the default: draft tokens sampled with proposal probability q, accepted with min(1, p/q), rejection resampled from the residual max(0, p−q), full accept drawn from the last verify position; the correction/bonus is deferred to the next step's certain token (one batched target pass per step, like greedy). --spec-draft-p-min opts into a looser, distribution-diverging accept. The greedy (temp ≤ 0) path is unchanged and byte-stable.
- New Sampler helpers: BuildFilteredDistribution / SampleWithDistribution / ResampleResidual.
- CLI relaxes the temp>0 fallback (model-draft only), threads sampling params + the seeded RNG, adds SHARPI_SPEC_SAMPLE=0 as a bisect switch (→ plain sampled decode), and rejects repeat-penalty / logit-bias with draft+temp>0 (v1: draft and target must agree on the distribution).
Tests
- Mock distribution-preservation oracle: the emitted histogram converges to softmax(target) regardless of the draft's q (the core invariant). Plus determinism and greedy-unchanged.
- Sampler unit tests for the new helpers.
- Gemma 4 CUDA BatchVerify parity (k=4/6, q8_0 KV, SWA ring-wrap, truncate/commit, e2e).
Validation (4070 Ti, 12 GB)
- Real E4B Q8_0 BatchVerify, including the ring-wrap oracle.
- e2e temp-0.2 sampled spec (q8_0 KV, -g -1): 73% acceptance, coherent output.
- SHARPI_SPEC_SAMPLE=0 falls back cleanly.
Throughput here was modest only because E4B-as-draft isn't small relative to the E4B target; the issue's headline ~1.3× needs a 12B target + a genuinely tiny draft. The mechanism is proven.
Kill-switches
- SHARPI_SPEC_SAMPLE=0: disable sampled spec at temp>0 (→ plain sampled decode), for bisection.
- SHARPI_SPEC_BATCH_VERIFY=0: existing sequential-verify fallback.
🤖 Generated with Claude Code
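The ring-slot argument from Gap A can be checked with a small standalone sketch. It only illustrates the (startPos+n) % ringSize indexing quoted above: RingSlotSketch, Slot and SlotsDistinct are hypothetical names, the numbers in the usage note are made up, and nothing here touches the real KV cache. The point is that k consecutive appends land on distinct slots whenever k does not exceed the ring size, including across a wrap.

```csharp
using System;
using System.Collections.Generic;

// Illustrative sketch only: hypothetical names, not the PR's cache code.
public static class RingSlotSketch
{
    // Ring slot for the n-th verify row, appended at absolute position startPos + n.
    public static int Slot(int startPos, int n, int ringSize) => (startPos + n) % ringSize;

    // True when the k append slots of one packed verify are pairwise distinct.
    public static bool SlotsDistinct(int startPos, int k, int ringSize)
    {
        var seen = new HashSet<int>();
        for (int n = 0; n < k; n++)
        {
            if (!seen.Add(Slot(startPos, n, ringSize))) return false;
        }
        return true;
    }
}
```

For example, with an assumed ring of 1024 slots and startPos = 1022, a k = 6 verify wraps past the end of the ring but still writes six distinct slots; the check only fails once k exceeds the ring size.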