bf16 + q8_0 KV cache for the dense CUDA path → 128K on Gemma 4 12B in 12 GB (#179)#184
Conversation
Adds opt-in half-width (bf16) KV storage to the dense CudaForwardPass, halving the KV footprint (~2x context headroom) on a 12 GB card. Arithmetic stays fp32 in the kernels — only the cache store is narrowed — so decode is argmax-stable vs fp32 KV. - SHARPI_KV_DTYPE (fp32|bf16) wired into CudaForwardPass; default fp32 (opt-in until long-context validated). Rejects bf16+TQ and bf16+explicit SnapKV; disables auto-SnapKV under bf16. - New SWA bf16 kernels llm_attention_swa_bf16 / _batched_bf16 (compose the SWA windowing/ring with bf16 K/V decode); backend wrappers + handles. - attn_scale added to llm_attention_bf16 (+ wrapper) so Gemma 4 globals (scale=1.0) work with bf16 KV; GDN callers keep the rsqrt default. - Graph-capture position tracking added to KvAppendBf16/AttentionBf16 so the full bf16 decode region is CUDA-graph-safe (GDN path never captures, so its behavior is unchanged). - Per-token decode/prefill dispatch (gemma4 + dense, plain + profiled) routes to the bf16 kernels via KvAppendKv/AttentionKv/AttentionSwaKv. - Batched/flash prefill gated off under bf16 (those kernels read an fp32 cache); bf16 prefill uses the per-token fallback. bf16 batched/flash prefill is the documented Increment 1.5 follow-up. - --kv-type CLI flag (run command) surfaces SHARPI_KV_DTYPE. - Parity test (Qwen3-8B Q4_K, Gemma 4 E4B Q8_0, Gemma 4 12B QAT Q4_0): teacher-forced bf16 vs fp32 KV is top-1/top-5 stable per position with logit max-abs within a per-model rounding budget. 3/3 green. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…1.5a) bf16 KV prefill no longer drops to the per-token loop for prompts within the 4096 shared-scores cap — it now uses the scalar batched bf16 kernels, matching the fp32 batched path's structure (one launch per layer instead of N). - attn_scale added to llm_full_seq_attention_bf16 (AttentionBatchedBf16) and llm_full_seq_attention_global_bf16 (AttentionBatchedWaveBf16) + wrappers, so Gemma 4 (scale=1.0) works; GDN callers keep the rsqrt default. - GpuLayerBatchedTrunk routes KvAppendBatched + the SWA/global attention to the bf16 variants under bf16 KV. - IsBatchedPrefillSupported no longer hard-disables bf16; the prefill gate forces canChunkPast4096=false under bf16 (flash/TC is still fp32-only — the 1.5b port), so prompts >4096 fall to the per-token loop until then. - Parity tests: bf16 batched prefill vs bf16 per-token is top-1/top-5 stable with logit max-abs within the per-model budget (Qwen3-8B, Gemma E4B, Gemma 12B). 6/6 KvDtype tests green; CLI `--kv-type bf16 -g -1` generates coherent text. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ent 1.5b) bf16 KV prefill now uses the same tensor-core flash kernel (Tc2) as fp32 for head_dim%64 models (Gemma 4 256, Qwen3 128), so it's full-speed at any length and can chunk a prompt past the 4096 scalar cap — the last gap before bf16 KV is viable as a default and before practical long-context (128K) validation. - Templated llm_flash_attn_prefill_tc2 into a __device__ impl<KV> + two extern "C" thunks (fp32 + _bf16) via a sharpi_kvload overload. impl<float> is byte-identical to the pre-#179 kernel, so the fp32 path is unchanged (22/22 fp32+bf16 batched prefill tests green); the bf16 thunk decodes each K/V element on load. - FlashAttentionPrefillTc2(..., bf16Cache:) selects the thunk; handle/load/validation wired. - CudaForwardPass: bf16 batched dispatch routes head_dim%64 layers to Tc2-bf16 (else scalar batched bf16); Bf16FlashTc2CoversAllLayers() re-enables chunked prefill past 4096 under bf16. attn_scale added to the scalar batched bf16 kernels (full_seq_attention_bf16 / _global_bf16) for Gemma's scale=1.0. - Tests: batched-prefill parity now compares fp32-Tc2 vs bf16-Tc2 (KV-dtype the only variable); new chunked >4096 E4B test crosses the SWA ring boundary — top-5 stable (long-context bar: coherent + argmax-stable, not tight logit parity). Tc/half2-flash bf16 thunks remain a trivial follow-up (only a non-%64 head_dim model past 4096 would need them). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
CudaForwardPass and GpuForwardPass keep the actual KV in VRAM but still constructed a full host-side KvCache (numLayers × maxSeqLen × kvDim × 2 floats) that they use ONLY for the position counter (Length/TruncateTo/Reset). On Gemma 4 12B that host allocation is ~50 GB of RAM at 64K context and OOMs the host above it — capping context well before VRAM is the limit, and masking the bf16 KV win. Adds KvCache.CreateBookkeepingOnly (tracks the counter, allocates no buffers; the storage accessors fault if misused) and uses it on both GPU passes. Combined with bf16 KV this unlocks the issue's headline: Gemma 4 12B QAT, CUDA -g -1, --kv-type bf16, runs at -c 131072 (128K) within 12 GB with coherent output (was OOM). 5/5 SnapKV-length + bf16-parity tests green; the host accessors were never called on the GPU paths. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
) Notes the --kv-type bf16 + bookkeeping-only host KvCache combo that takes the 12B to 128K within 12 GB (fp32 caps ~32K), the chunked Tc2-flash bf16 prefill, and the measured decode behavior: full-speed (~53 t/s) through 32K, dropping to ~21 t/s at 128K when VRAM saturates and the tied ~818 MB Q6_K embed/LM-head table spills to host per token (q8_0 KV would restore it). Default stays fp32. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Block-quantized KV cache (--kv-type q8_0): ggml block_q8_0 layout (32 int8 + 1 fp16 scale = 34 B/block, ~quarter fp32 / half bf16). At -c 131072 on the Gemma 4 12B (-g -1) this frees ~1.1 GB over bf16, keeping the tied ~818 MB Q6_K embed/LM-head table GPU-resident so decode holds full speed ~53 t/s vs bf16's ~19 t/s at 128K (2.8x), same coherent output. - One sharpi_kvload(const block_q8_0*) overload is the single variation point; templated the 6 bf16 read kernels on <KV> (attention/swa per-token, swa-batched, full_seq, full_seq_global, Tc2-flash) with bf16/q8_0 thunks. The bf16 thunk is byte-identical (sharpi_kvload(ushort)==sharpi_bf16_to_fp32); bf16 oracle re-run green. - New store kernels llm_kv_append_q8_0/_batched quantize per 32-lane warp (amax -> d=amax/127 -> rintf clamp +-127, mirrors llm_quantize_q8_1). - CudaBackend.Allocate now sizes via DTypeInfo.ByteSize (identical for scalar, adds block-quantized) so the cache allocates as DType.Q8_0. - ParseKvDType q8_0; dispatch + batched-prefill + chunked->4096 gates generalized to kvNarrowed; CLI --kv-type q8_0; same TQ/explicit-SnapKV rejection + auto-SnapKV-off as bf16. - Tests: q8_0 argmax-stable + batched + chunked parity (Qwen3-8B, Gemma E4B, Gemma 12B); bf16 oracle preserved via shared kvDtype helpers. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ts (#179) Addresses pr-review-toolkit findings on Increment 2: - Silent-failure (HIGH): fail loud at the q8_0 KV allocation if any layer's kvDim isn't a multiple of 32. DTypeInfo.ByteSize's count/32 sizing and the store kernels' per-warp amax assume it; every dense GGUF head_dim satisfies it, but throw a clear NotSupportedException instead of silently under-allocating + corrupting on a future geometry. - Tests: add Gemma4_12B_Q8ChunkedPrefill_MatchesFp32 — the 128K driving model (k_eq_v globals) crossing the chunk/SWA-ring boundary under q8_0, the headline config that the E4B-only chunked test didn't cover. Add a CPU-only DTypeInfo.ByteSize(.., Q8_0)==34 unit assert guarding the KV-cache sizing. - Comments: refresh 6 kernel headers that still read bf16-specific / "two thunks" / "default KV dtype" after the <KV> templating; tighten the block_q8_0 alignment note to the one load-bearing invariant (kv_dim % 32). All green: 12B q8 chunked parity (12m57s), ByteSize unit, prior q8 + bf16 oracles unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces support for narrowed KV-cache element types (bf16 and q8_0) in the CUDA backend to significantly reduce VRAM footprint and enable longer context lengths. It adds command-line options, templates the CUDA attention and flash kernels to support the new dtypes, and implements a bookkeeping-only host KvCache to prevent host OOM errors. Additionally, comprehensive parity and long-context tests are added. Feedback on the changes points out a potential NaN generation issue in the q8_0 quantization kernel when dealing with extremely small subnormal floats, which could cause zero values to be incorrectly quantized to -127, and suggests using a small threshold instead of a strict zero check.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| float d = a / 127.f; | ||
| float invd = (d == 0.f) ? 0.f : (1.f / d); | ||
| int q = (int)rintf(val * invd); |
There was a problem hiding this comment.
When d is extremely small (e.g., a subnormal float in the range [1.4e-45, 2.94e-39]), 1.f / d will overflow to infinity. For any lane where val is 0.f, this results in 0.f * infinity = NaN. Casting NaN to int yields an indefinite integer value (typically -2147483648), which gets clamped to -127 instead of the expected 0.\n\nTo prevent NaN generation and ensure robust quantization for extremely small values, consider using a small threshold (e.g., 1e-30f) instead of a strict 0.f check.
float d = a / 127.f;
float invd = (d < 1e-30f) ? 0.f : (1.f / d);
int q = (int)rintf(val * invd);The CLI's --kv-type / SHARPI_KV_DTYPE had no server analogue. Add it to the
SharpInferenceServerOptions surface so an operator can pick the narrowed KV
cache from appsettings.json the same way they pick TurboQuant or the backend.
- SharpInferenceServerOptions.KvType (string?, mirrors --kv-type: fp32/bf16/q8_0).
- InferenceEngineLoader forwards it to SHARPI_KV_DTYPE before BuildForwardPass
(CudaForwardPass reads the env var in its ctor), only when explicitly set so
an externally-set env var still works; the value is validated at model load.
- Server.Host maps the SHARPI_KV_DTYPE env override to the option, mirroring the
SHARPI_MODEL / SHARPI_BACKEND pattern.
- Tests: KvType default (null) + appsettings binding ("q8_0").
Has effect only on the CUDA dense path; ignored by CPU/Vulkan and rejected with
TurboQuant / explicit SnapKV, same as the CLI.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
Added the narrowed-KV option to the API server too (
|
Trivial review follow-ups (defense-in-depth, no behavior change for valid input): - FlashAttentionPrefillTc2's dtype selector now routes fp32 explicitly and throws ArgumentOutOfRangeException on an unexpected dtype, instead of a `_ => fp32` fall-through that would silently stride a narrowed (q8_0) cache through the fp32 kernel if a future KV dtype were added without updating here. - EnsureImageKernels logs the full exception (ToString) not just Message, so a GetKernelFunc bind failure surfaces which kernel didn't load. Qwen3-8B q8_0 + bf16 batched prefill (flash Tc2 path) still green. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Addresses the gemini-code-assist review on PR #184: in sharpi_q8_append_one a sub-block whose amax is a subnormal float makes d=amax/127 subnormal, so 1/d overflows to +inf and a zero lane's 0*inf = NaN → (int)NaN clamps to -127 instead of 0. Use a 1e-30 threshold instead of the strict d==0 check (matching the suggestion). Far below any real KV scale (d ≈ amax/127 ≈ 1e-3..1e-1), so real blocks are unaffected and parity is unchanged; an all-near-zero block now quantizes to all-zeros, correct to within its negligible magnitude. Qwen3-8B q8_0 per-token + batched parity still green. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
Good catch on the subnormal-scale edge — applied in |
Adds opt-in narrowed KV-cache dtypes (
--kv-type bf16|q8_0/SHARPI_KV_DTYPE) to the dense CUDA forward pass, plus a host-allocation fix that was the real context ceiling. Takes the Gemma 4 12B QAT (-g -1) from a ~32K ceiling to-c 131072(128K) within 12 GB, coherent. Default KV stays fp32; both narrowed dtypes are opt-in.Result (Gemma 4 12B QAT Q4_0, RTX 4070 Ti 12 GB,
-g -1)cudaMalloc-fails at 64K)At 128K the tied ~818 MB Q6_K embed/LM-head table can't stay GPU-resident under bf16's KV (~350 MB free) and spills to host per token → decode cliffs. q8_0's block-quantized cache (~¼ fp32, ~½ bf16) frees ~1.1 GB more, keeps the table resident, and holds full speed. Same coherent output across all three dtypes.
What's in here (squashes #179 increments 1 → 2)
--kv-typeCLI flag (rejects bf16+TQ / explicit-SnapKV).KvCache.CreateBookkeepingOnly: the GPU passes kept the real KV in VRAM but still allocated a full host KV buffer used only for the position counter (~50 GB at 64K on the 12B → managedOutOfMemoryException). This was the hidden ceiling, masking the bf16 win.block_q8_0layout (32 int8 + fp16 scale = 34 B/block). Onesharpi_kvload(const block_q8_0*)overload is the single variation point; the 6 bf16 read kernels were templated on<KV>with bf16/q8_0 thunks (bf16 thunk byte-identical, sincesharpi_kvload(unsigned short) == sharpi_bf16_to_fp32). New store kernels quantize per 32-lane warp (amax → d=amax/127 → rintfclamp ±127, mirroringllm_quantize_q8_1).CudaBackend.Allocatenow sizes viaDTypeInfo.ByteSize(identical for scalar, adds block-quantized).Validation
feedback_cross_backend_parity_test):DTypeInfo.ByteSize(.., Q8_0)==34CPU unit assert.-c 131072on the 12B: q8_0 53.3 t/s vs bf16 18.8 t/s, both coherent.Review (pr-review-toolkit)
code-reviewer: no critical/important issues. Addressed: a fail-loud guard for the q8_0
kvDim % 32invariant (was a latent silent under-allocation; all current models satisfy it), the 12B chunked test above, the ByteSize unit test, and 6 stale kernel-header comments. Knowingly deferred: the_ => fp32dtype-selector fall-through (future-proofing only —_kvDTypeis centrally validated to {fp32,bf16,q8_0}) and a pre-existingEnsureImageKernelslogging nit.Notes / follow-ups
%64head_dim model (none on disk) — same latent-but-unexercised status as their bf16 twins.🤖 Generated with Claude Code