Skip to content

bf16 + q8_0 KV cache for the dense CUDA path → 128K on Gemma 4 12B in 12 GB (#179)#184

Merged
pekkah merged 10 commits into
masterfrom
feat/bf16-kv-dense-cuda-179
Jun 9, 2026
Merged

bf16 + q8_0 KV cache for the dense CUDA path → 128K on Gemma 4 12B in 12 GB (#179)#184
pekkah merged 10 commits into
masterfrom
feat/bf16-kv-dense-cuda-179

Conversation

@pekkah

@pekkah pekkah commented Jun 9, 2026

Copy link
Copy Markdown
Owner

Adds opt-in narrowed KV-cache dtypes (--kv-type bf16|q8_0 / SHARPI_KV_DTYPE) to the dense CUDA forward pass, plus a host-allocation fix that was the real context ceiling. Takes the Gemma 4 12B QAT (-g -1) from a ~32K ceiling to -c 131072 (128K) within 12 GB, coherent. Default KV stays fp32; both narrowed dtypes are opt-in.

Result (Gemma 4 12B QAT Q4_0, RTX 4070 Ti 12 GB, -g -1)

KV dtype Max ctx in 12 GB Decode @128k
fp32 (default) ~32K (cudaMalloc-fails at 64K)
bf16 128K 18.8 t/s (embed-table cliff)
q8_0 128K 53.3 t/s — full speed (2.8× over bf16)

At 128K the tied ~818 MB Q6_K embed/LM-head table can't stay GPU-resident under bf16's KV (~350 MB free) and spills to host per token → decode cliffs. q8_0's block-quantized cache (~¼ fp32, ~½ bf16) frees ~1.1 GB more, keeps the table resident, and holds full speed. Same coherent output across all three dtypes.

What's in here (squashes #179 increments 1 → 2)

  • Increment 1 — bf16 KV storage + SWA/global bf16 kernels, per-token decode/prefill dispatch, --kv-type CLI flag (rejects bf16+TQ / explicit-SnapKV).
  • Increment 1.5a/b — bf16 scalar batched prefill (≤4096) + bf16 Tc2 tensor-core flash prefill incl. chunked >4096.
  • Host-cache fixKvCache.CreateBookkeepingOnly: the GPU passes kept the real KV in VRAM but still allocated a full host KV buffer used only for the position counter (~50 GB at 64K on the 12B → managed OutOfMemoryException). This was the hidden ceiling, masking the bf16 win.
  • Increment 2 — q8_0 KV (this PR's headline). ggml block_q8_0 layout (32 int8 + fp16 scale = 34 B/block). One sharpi_kvload(const block_q8_0*) overload is the single variation point; the 6 bf16 read kernels were templated on <KV> with bf16/q8_0 thunks (bf16 thunk byte-identical, since sharpi_kvload(unsigned short) == sharpi_bf16_to_fp32). New store kernels quantize per 32-lane warp (amax → d=amax/127 → rintf clamp ±127, mirroring llm_quantize_q8_1). CudaBackend.Allocate now sizes via DTypeInfo.ByteSize (identical for scalar, adds block-quantized).

Validation

  • Parity oracles (logit-level top-5 argmax-stable vs an independent fp32 trajectory, teacher-forced so KV dtype is the only variable, per feedback_cross_backend_parity_test):
    • bf16: per-token + batched (Qwen3-8B, E4B) + chunked (E4B) — all green, preserved through the templating.
    • q8_0: per-token + batched (Qwen3-8B, E4B, 12B) + chunked >4096 (E4B and 12B/k_eq_v) — all green.
    • DTypeInfo.ByteSize(.., Q8_0)==34 CPU unit assert.
  • CLI A/B at -c 131072 on the 12B: q8_0 53.3 t/s vs bf16 18.8 t/s, both coherent.

Review (pr-review-toolkit)

code-reviewer: no critical/important issues. Addressed: a fail-loud guard for the q8_0 kvDim % 32 invariant (was a latent silent under-allocation; all current models satisfy it), the 12B chunked test above, the ByteSize unit test, and 6 stale kernel-header comments. Knowingly deferred: the _ => fp32 dtype-selector fall-through (future-proofing only — _kvDType is centrally validated to {fp32,bf16,q8_0}) and a pre-existing EnsureImageKernels logging nit.

Notes / follow-ups

  • Non-SWA scalar batched q8 kernels + the wave (>4096 global) q8 kernel are wired but only reachable by a non-%64 head_dim model (none on disk) — same latent-but-unexercised status as their bf16 twins.
  • Open follow-ups: auto-narrow default when ctx won't fit fp32; Tc/half2-flash q8 thunks; q8_0 draft-model headroom (GPU draft-MTP speculative decoding for Gemma 4 12B (decode 54 → ~70 t/s) #178).

🤖 Generated with Claude Code

pekkah and others added 7 commits June 9, 2026 13:03
Adds opt-in half-width (bf16) KV storage to the dense CudaForwardPass,
halving the KV footprint (~2x context headroom) on a 12 GB card.
Arithmetic stays fp32 in the kernels — only the cache store is narrowed —
so decode is argmax-stable vs fp32 KV.

- SHARPI_KV_DTYPE (fp32|bf16) wired into CudaForwardPass; default fp32
  (opt-in until long-context validated). Rejects bf16+TQ and bf16+explicit
  SnapKV; disables auto-SnapKV under bf16.
- New SWA bf16 kernels llm_attention_swa_bf16 / _batched_bf16 (compose the
  SWA windowing/ring with bf16 K/V decode); backend wrappers + handles.
- attn_scale added to llm_attention_bf16 (+ wrapper) so Gemma 4 globals
  (scale=1.0) work with bf16 KV; GDN callers keep the rsqrt default.
- Graph-capture position tracking added to KvAppendBf16/AttentionBf16 so the
  full bf16 decode region is CUDA-graph-safe (GDN path never captures, so
  its behavior is unchanged).
- Per-token decode/prefill dispatch (gemma4 + dense, plain + profiled)
  routes to the bf16 kernels via KvAppendKv/AttentionKv/AttentionSwaKv.
- Batched/flash prefill gated off under bf16 (those kernels read an fp32
  cache); bf16 prefill uses the per-token fallback. bf16 batched/flash
  prefill is the documented Increment 1.5 follow-up.
- --kv-type CLI flag (run command) surfaces SHARPI_KV_DTYPE.
- Parity test (Qwen3-8B Q4_K, Gemma 4 E4B Q8_0, Gemma 4 12B QAT Q4_0):
  teacher-forced bf16 vs fp32 KV is top-1/top-5 stable per position with
  logit max-abs within a per-model rounding budget. 3/3 green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…1.5a)

bf16 KV prefill no longer drops to the per-token loop for prompts within the
4096 shared-scores cap — it now uses the scalar batched bf16 kernels, matching
the fp32 batched path's structure (one launch per layer instead of N).

- attn_scale added to llm_full_seq_attention_bf16 (AttentionBatchedBf16) and
  llm_full_seq_attention_global_bf16 (AttentionBatchedWaveBf16) + wrappers, so
  Gemma 4 (scale=1.0) works; GDN callers keep the rsqrt default.
- GpuLayerBatchedTrunk routes KvAppendBatched + the SWA/global attention to the
  bf16 variants under bf16 KV.
- IsBatchedPrefillSupported no longer hard-disables bf16; the prefill gate forces
  canChunkPast4096=false under bf16 (flash/TC is still fp32-only — the 1.5b port),
  so prompts >4096 fall to the per-token loop until then.
- Parity tests: bf16 batched prefill vs bf16 per-token is top-1/top-5 stable with
  logit max-abs within the per-model budget (Qwen3-8B, Gemma E4B, Gemma 12B). 6/6
  KvDtype tests green; CLI `--kv-type bf16 -g -1` generates coherent text.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ent 1.5b)

bf16 KV prefill now uses the same tensor-core flash kernel (Tc2) as fp32 for
head_dim%64 models (Gemma 4 256, Qwen3 128), so it's full-speed at any length
and can chunk a prompt past the 4096 scalar cap — the last gap before bf16 KV
is viable as a default and before practical long-context (128K) validation.

- Templated llm_flash_attn_prefill_tc2 into a __device__ impl<KV> + two extern "C"
  thunks (fp32 + _bf16) via a sharpi_kvload overload. impl<float> is byte-identical
  to the pre-#179 kernel, so the fp32 path is unchanged (22/22 fp32+bf16 batched
  prefill tests green); the bf16 thunk decodes each K/V element on load.
- FlashAttentionPrefillTc2(..., bf16Cache:) selects the thunk; handle/load/validation
  wired.
- CudaForwardPass: bf16 batched dispatch routes head_dim%64 layers to Tc2-bf16
  (else scalar batched bf16); Bf16FlashTc2CoversAllLayers() re-enables chunked
  prefill past 4096 under bf16. attn_scale added to the scalar batched bf16 kernels
  (full_seq_attention_bf16 / _global_bf16) for Gemma's scale=1.0.
- Tests: batched-prefill parity now compares fp32-Tc2 vs bf16-Tc2 (KV-dtype the only
  variable); new chunked >4096 E4B test crosses the SWA ring boundary — top-5 stable
  (long-context bar: coherent + argmax-stable, not tight logit parity).

Tc/half2-flash bf16 thunks remain a trivial follow-up (only a non-%64 head_dim
model past 4096 would need them).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
CudaForwardPass and GpuForwardPass keep the actual KV in VRAM but still
constructed a full host-side KvCache (numLayers × maxSeqLen × kvDim × 2 floats)
that they use ONLY for the position counter (Length/TruncateTo/Reset). On Gemma 4
12B that host allocation is ~50 GB of RAM at 64K context and OOMs the host above
it — capping context well before VRAM is the limit, and masking the bf16 KV win.

Adds KvCache.CreateBookkeepingOnly (tracks the counter, allocates no buffers; the
storage accessors fault if misused) and uses it on both GPU passes. Combined with
bf16 KV this unlocks the issue's headline: Gemma 4 12B QAT, CUDA -g -1,
--kv-type bf16, runs at -c 131072 (128K) within 12 GB with coherent output
(was OOM). 5/5 SnapKV-length + bf16-parity tests green; the host accessors were
never called on the GPU paths.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
)

Notes the --kv-type bf16 + bookkeeping-only host KvCache combo that takes the
12B to 128K within 12 GB (fp32 caps ~32K), the chunked Tc2-flash bf16 prefill,
and the measured decode behavior: full-speed (~53 t/s) through 32K, dropping to
~21 t/s at 128K when VRAM saturates and the tied ~818 MB Q6_K embed/LM-head table
spills to host per token (q8_0 KV would restore it). Default stays fp32.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Block-quantized KV cache (--kv-type q8_0): ggml block_q8_0 layout
(32 int8 + 1 fp16 scale = 34 B/block, ~quarter fp32 / half bf16). At
-c 131072 on the Gemma 4 12B (-g -1) this frees ~1.1 GB over bf16,
keeping the tied ~818 MB Q6_K embed/LM-head table GPU-resident so decode
holds full speed ~53 t/s vs bf16's ~19 t/s at 128K (2.8x), same coherent
output.

- One sharpi_kvload(const block_q8_0*) overload is the single variation
  point; templated the 6 bf16 read kernels on <KV> (attention/swa
  per-token, swa-batched, full_seq, full_seq_global, Tc2-flash) with
  bf16/q8_0 thunks. The bf16 thunk is byte-identical
  (sharpi_kvload(ushort)==sharpi_bf16_to_fp32); bf16 oracle re-run green.
- New store kernels llm_kv_append_q8_0/_batched quantize per 32-lane warp
  (amax -> d=amax/127 -> rintf clamp +-127, mirrors llm_quantize_q8_1).
- CudaBackend.Allocate now sizes via DTypeInfo.ByteSize (identical for
  scalar, adds block-quantized) so the cache allocates as DType.Q8_0.
- ParseKvDType q8_0; dispatch + batched-prefill + chunked->4096 gates
  generalized to kvNarrowed; CLI --kv-type q8_0; same TQ/explicit-SnapKV
  rejection + auto-SnapKV-off as bf16.
- Tests: q8_0 argmax-stable + batched + chunked parity (Qwen3-8B, Gemma
  E4B, Gemma 12B); bf16 oracle preserved via shared kvDtype helpers.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ts (#179)

Addresses pr-review-toolkit findings on Increment 2:

- Silent-failure (HIGH): fail loud at the q8_0 KV allocation if any layer's
  kvDim isn't a multiple of 32. DTypeInfo.ByteSize's count/32 sizing and the
  store kernels' per-warp amax assume it; every dense GGUF head_dim satisfies
  it, but throw a clear NotSupportedException instead of silently
  under-allocating + corrupting on a future geometry.
- Tests: add Gemma4_12B_Q8ChunkedPrefill_MatchesFp32 — the 128K driving model
  (k_eq_v globals) crossing the chunk/SWA-ring boundary under q8_0, the headline
  config that the E4B-only chunked test didn't cover. Add a CPU-only
  DTypeInfo.ByteSize(.., Q8_0)==34 unit assert guarding the KV-cache sizing.
- Comments: refresh 6 kernel headers that still read bf16-specific / "two
  thunks" / "default KV dtype" after the <KV> templating; tighten the
  block_q8_0 alignment note to the one load-bearing invariant (kv_dim % 32).

All green: 12B q8 chunked parity (12m57s), ByteSize unit, prior q8 + bf16
oracles unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for narrowed KV-cache element types (bf16 and q8_0) in the CUDA backend to significantly reduce VRAM footprint and enable longer context lengths. It adds command-line options, templates the CUDA attention and flash kernels to support the new dtypes, and implements a bookkeeping-only host KvCache to prevent host OOM errors. Additionally, comprehensive parity and long-context tests are added. Feedback on the changes points out a potential NaN generation issue in the q8_0 quantization kernel when dealing with extremely small subnormal floats, which could cause zero values to be incorrectly quantized to -127, and suggests using a small threshold instead of a strict zero check.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +674 to +676
float d = a / 127.f;
float invd = (d == 0.f) ? 0.f : (1.f / d);
int q = (int)rintf(val * invd);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When d is extremely small (e.g., a subnormal float in the range [1.4e-45, 2.94e-39]), 1.f / d will overflow to infinity. For any lane where val is 0.f, this results in 0.f * infinity = NaN. Casting NaN to int yields an indefinite integer value (typically -2147483648), which gets clamped to -127 instead of the expected 0.\n\nTo prevent NaN generation and ensure robust quantization for extremely small values, consider using a small threshold (e.g., 1e-30f) instead of a strict 0.f check.

    float d    = a / 127.f;
    float invd = (d < 1e-30f) ? 0.f : (1.f / d);
    int   q    = (int)rintf(val * invd);

The CLI's --kv-type / SHARPI_KV_DTYPE had no server analogue. Add it to the
SharpInferenceServerOptions surface so an operator can pick the narrowed KV
cache from appsettings.json the same way they pick TurboQuant or the backend.

- SharpInferenceServerOptions.KvType (string?, mirrors --kv-type: fp32/bf16/q8_0).
- InferenceEngineLoader forwards it to SHARPI_KV_DTYPE before BuildForwardPass
  (CudaForwardPass reads the env var in its ctor), only when explicitly set so
  an externally-set env var still works; the value is validated at model load.
- Server.Host maps the SHARPI_KV_DTYPE env override to the option, mirroring the
  SHARPI_MODEL / SHARPI_BACKEND pattern.
- Tests: KvType default (null) + appsettings binding ("q8_0").

Has effect only on the CUDA dense path; ignored by CPU/Vulkan and rejected with
TurboQuant / explicit SnapKV, same as the CLI.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@pekkah

pekkah commented Jun 9, 2026

Copy link
Copy Markdown
Owner Author

Added the narrowed-KV option to the API server too (4bdbd83), so it isn't CLI-only:

  • SharpInferenceServerOptions.KvType (fp32/bf16/q8_0) — bindable from appsettings.json (SharpInference:KvType), inline Configure, or the SHARPI_KV_DTYPE env override (mirrors the SHARPI_MODEL/SHARPI_BACKEND pattern).
  • InferenceEngineLoader forwards it to SHARPI_KV_DTYPE before building the forward pass (only when explicitly set, so an externally-set env var still works); validated at model load.
  • Has effect only on the CUDA dense path; ignored by CPU/Vulkan and rejected with TurboQuant / explicit SnapKV — same semantics as the CLI --kv-type.
  • Tests: KvType default + appsettings binding (q8_0).

pekkah and others added 2 commits June 9, 2026 20:49
Trivial review follow-ups (defense-in-depth, no behavior change for valid input):

- FlashAttentionPrefillTc2's dtype selector now routes fp32 explicitly and
  throws ArgumentOutOfRangeException on an unexpected dtype, instead of a
  `_ => fp32` fall-through that would silently stride a narrowed (q8_0) cache
  through the fp32 kernel if a future KV dtype were added without updating here.
- EnsureImageKernels logs the full exception (ToString) not just Message, so a
  GetKernelFunc bind failure surfaces which kernel didn't load.

Qwen3-8B q8_0 + bf16 batched prefill (flash Tc2 path) still green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Addresses the gemini-code-assist review on PR #184: in sharpi_q8_append_one a
sub-block whose amax is a subnormal float makes d=amax/127 subnormal, so 1/d
overflows to +inf and a zero lane's 0*inf = NaN → (int)NaN clamps to -127
instead of 0. Use a 1e-30 threshold instead of the strict d==0 check (matching
the suggestion). Far below any real KV scale (d ≈ amax/127 ≈ 1e-3..1e-1), so
real blocks are unaffected and parity is unchanged; an all-near-zero block now
quantizes to all-zeros, correct to within its negligible magnitude.

Qwen3-8B q8_0 per-token + batched parity still green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@pekkah

pekkah commented Jun 9, 2026

Copy link
Copy Markdown
Owner Author

Good catch on the subnormal-scale edge — applied in 003bacf. A sub-block whose amax is subnormal makes d=amax/127 subnormal, 1/d overflows to +inf, and a zero lane's 0*inf = NaN → (int)NaN clamps to -127. Switched to the d < 1e-30f threshold as suggested. It's far below any real KV scale (d ≈ amax/127 ≈ 1e-3..1e-1), so real blocks are unaffected and the parity oracles are unchanged; an all-near-zero block now correctly quantizes to all-zeros.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant