perf(gemma4): GPU-side PLE pre-pass dequant for batched prefill (#247 item 1)#256
Conversation
…item 1) The Gemma 4 batched-prefill PLE pre-pass (BuildPerLayerProjectionsBatched) dequanted each token's per_layer_token_embd row on the CPU (Parallel.For over N x stackedDim elements) and then uploaded the full f32 rows -- ~30% of batched prefill per the original #141 profiling. Move the dequant to the GPU: gather the chunk's raw packed rows on the CPU (a memcpy, no FP math), upload the 4x-smaller quant bytes, then dequant on-device into _bpPleRowAll via new contiguous-row kernels. The PLE table stays CPU-resident (TierPlanner budget and auto-context untouched, no extra VRAM), so this is a no-regression win on the 12 GB card. The 12B has no PLE table (embedding_length_per_layer_input=0); PLE is E2B/E4B only (Q8_0 or Q6_K). - New kernels llm_dequant_rows_q8_0 / llm_dequant_rows_q6k (grid.x=n_rows, block=256, row_dim % 256); backend DequantRowsQ8_0/Q6K + UploadRawInto. - Bit-identical to CPU Dequantize.ToFloat32 (both (d*scale)*q, exact cvt.f32.f16) so the GemmOff bit-exact prefill oracle still holds. New CudaDequantRowsTests proves it directly for Q8_0 + Q6_K incl. the real 10752-wide row. - Kill-switch SHARPI_PLE_GPU_DEQUANT=0 forces the CPU fallback (also the path for an F32 table or a quant without a dequant-rows kernel). Measured (E4B-Q8_0, -g -1, warm): PLE phase 80->63 ms @n=4096 (~21%), 8->4 ms @n=972; argmax-stable and bit-exact vs the per-token loop. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request implements GPU-side PLE (Per-Layer Embedding) pre-pass dequantization for Gemma-4 (issue #247) by introducing CUDA kernels for Q8_0 and Q6_K packed rows. This optimization offloads the expensive dequantization math from the CPU to the GPU and reduces host-to-device upload sizes by 4x. Feedback on the changes suggests adding an early return for empty source spans in UploadRawInto to prevent potential null pointer issues, optimizing the hot Parallel.For loop in BuildPerLayerProjectionsBatched by pinning the array and using Buffer.MemoryCopy instead of slicing spans, and using BitConverter.HalfToUInt16Bits in unit tests to avoid heap allocations.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| public void UploadRawInto(Tensor dst, ReadOnlySpan<byte> src) | ||
| { | ||
| if (!_devPtrs.TryGetValue(dst.Handle, out var entry)) | ||
| throw new InvalidOperationException($"UploadRawInto: handle {dst.Handle} not registered."); | ||
| if ((nuint)src.Length > entry.byteSize) | ||
| throw new ArgumentException($"UploadRawInto: source ({src.Length} bytes) exceeds destination capacity ({entry.byteSize} bytes)."); | ||
| fixed (byte* s = src) | ||
| UploadViaStaging(entry.devPtr, s, (nuint)src.Length); | ||
| } |
There was a problem hiding this comment.
If src is empty, calling fixed (byte* s = src) will result in s being null. Depending on how UploadViaStaging handles null pointers and zero lengths, this could lead to undefined behavior or unnecessary overhead. Adding an early return when src is empty is a safer and more defensive approach.
public void UploadRawInto(Tensor dst, ReadOnlySpan<byte> src)
{
if (src.IsEmpty)
return;
if (!_devPtrs.TryGetValue(dst.Handle, out var entry))
throw new InvalidOperationException($"UploadRawInto: handle {dst.Handle} not registered.");
if ((nuint)src.Length > entry.byteSize)
throw new ArgumentException($"UploadRawInto: source ({src.Length} bytes) exceeds destination capacity ({entry.byteSize} bytes).");
fixed (byte* s = src)
UploadViaStaging(entry.devPtr, s, (nuint)src.Length);
}| int bpr = _pleBytesPerRow; | ||
| var qhost = _bpPleQuantHost!; | ||
| System.Threading.Tasks.Parallel.For(0, N, i => | ||
| { | ||
| byte* rowPtr = basePtr + (long)tokens[i] * bpr; | ||
| new ReadOnlySpan<byte>(rowPtr, bpr) | ||
| .CopyTo(new Span<byte>(qhost).Slice(i * bpr, bpr)); | ||
| }); |
There was a problem hiding this comment.
Creating Span and ReadOnlySpan instances and slicing them inside a hot Parallel.For loop introduces unnecessary overhead. Since Parallel.For is a blocking call, we can pin the qhost array once outside the loop using a fixed statement and perform a direct memory copy using Buffer.MemoryCopy. This avoids span allocation/slicing overhead and is significantly faster.
int bpr = _pleBytesPerRow;
var qhost = _bpPleQuantHost!;
fixed (byte* qhostPtr = qhost)
{
System.Threading.Tasks.Parallel.For(0, N, i =>
{
byte* rowPtr = basePtr + (long)tokens[i] * bpr;
Buffer.MemoryCopy(rowPtr, qhostPtr + (long)i * bpr, bpr, bpr);
});
}| private static ushort HalfToUshort(Half h) => | ||
| BitConverter.ToUInt16(BitConverter.GetBytes(h), 0); |
There was a problem hiding this comment.
Using BitConverter.GetBytes(h) on a Half is highly inefficient because it allocates a new byte[] array on the heap on every call. Since this method is called thousands of times in a loop during testing, it creates significant GC pressure. Using BitConverter.HalfToUInt16Bits(h) completely avoids allocations and is extremely fast.
private static ushort HalfToUshort(Half h) =>
BitConverter.HalfToUInt16Bits(h);|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request implements GPU-side dequantization for Gemma-4 PLE pre-pass rows (supporting Q8_0 and Q6_K formats), moving the work off the CPU and reducing host-to-device upload sizes by 4x. It also adds comprehensive unit tests to ensure bit-identical results with the CPU fallback. The feedback suggests adding a defensive guard for empty token lists to prevent an ArgumentOutOfRangeException, and hoisting Span creations out of the Parallel.For loops in both the GPU and CPU fallback paths to optimize performance.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| private void BuildPerLayerProjectionsBatched(IReadOnlyList<int> tokens) | ||
| { | ||
| int N = tokens.Count, L = _hp.NumLayers; | ||
| int stackedDim = L * _pleWidth; |
There was a problem hiding this comment.
Consider adding a defensive guard to handle the case where tokens is empty (N == 0). If N is 0, executing the GPU dequantization path will call DequantRowsQ8_0 or DequantRowsQ6K with nRows = 0, which throws an ArgumentOutOfRangeException.
private void BuildPerLayerProjectionsBatched(IReadOnlyList<int> tokens)
{
int N = tokens.Count;
if (N == 0) return;
int L = _hp.NumLayers;
int stackedDim = L * _pleWidth;| int bpr = _pleBytesPerRow; | ||
| var qhost = _bpPleQuantHost!; | ||
| System.Threading.Tasks.Parallel.For(0, N, i => | ||
| { | ||
| byte* rowPtr = basePtr + (long)tokens[i] * bpr; | ||
| new ReadOnlySpan<byte>(rowPtr, bpr) | ||
| .CopyTo(new Span<byte>(qhost).Slice(i * bpr, bpr)); | ||
| }); |
There was a problem hiding this comment.
To avoid repeatedly creating a new Span<byte> wrapper around qhost on every iteration of the parallel loop, you can hoist the span creation outside of the Parallel.For loop.
int bpr = _pleBytesPerRow;
var qhost = _bpPleQuantHost!;
Span<byte> qhostSpan = qhost;
System.Threading.Tasks.Parallel.For(0, N, i =>
{
byte* rowPtr = basePtr + (long)tokens[i] * bpr;
new ReadOnlySpan<byte>(rowPtr, bpr)
.CopyTo(qhostSpan.Slice(i * bpr, bpr));
});| int bytesPerRow = _pleBytesPerRow; | ||
| var host = _bpPleRowHostAll!; | ||
| var dtype = _pleDType; | ||
| System.Threading.Tasks.Parallel.For(0, N, i => | ||
| { | ||
| byte* rowPtr = basePtr + (long)tokens[i] * bytesPerRow; | ||
| var dst = new Span<float>(host).Slice(i * stackedDim, stackedDim); | ||
| if (dtype == DType.Float32) | ||
| new ReadOnlySpan<float>((float*)rowPtr, stackedDim).CopyTo(dst); | ||
| else | ||
| Dequantize.ToFloat32(new ReadOnlySpan<byte>(rowPtr, bytesPerRow), dst, dtype, stackedDim); | ||
| }); |
There was a problem hiding this comment.
Similarly, in the CPU fallback path, you can hoist the creation of the Span<float> wrapper around host outside of the Parallel.For loop to avoid redundant allocations on each iteration.
int bytesPerRow = _pleBytesPerRow;
var host = _bpPleRowHostAll!;
Span<float> hostSpan = host;
var dtype = _pleDType;
System.Threading.Tasks.Parallel.For(0, N, i =>
{
byte* rowPtr = basePtr + (long)tokens[i] * bytesPerRow;
var dst = hostSpan.Slice(i * stackedDim, stackedDim);
if (dtype == DType.Float32)
new ReadOnlySpan<float>((float*)rowPtr, stackedDim).CopyTo(dst);
else
Dequantize.ToFloat32(new ReadOnlySpan<byte>(rowPtr, bytesPerRow), dst, dtype, stackedDim);
});…her bound Review-cycle follow-ups to #256: - BuildQ8_0Rows now spans the full int8 range [-128,127] (was [-127,127]), so the bit-identity assertion exercises the -128 quant too. - Document that N*bpr stays within int at the 4096 chunk cap. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…in test Address Gemini review on #256: - UploadRawInto returns early on an empty span (after validation) to avoid fixed → null-ptr. - Test HalfToUshort uses BitConverter.HalfToUInt16Bits (no per-call byte[]). Not applied: the Parallel.For span/pointer hoisting suggestions don't compile (Span<byte> is a ref struct, can't be captured in the lambda; new Span(array) is already allocation-free), and the N==0 guard targets an unreachable case (empty prompts are never prefilled). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
Thanks for the review. Disposition of the findings: Applied (commit 8af6870):
Also covered the full int8 range incl. Not applied (with rationale):
|
Implements item #1 of the #247 Gemma 4 prefill attack plan — the PLE pre-pass dequant — and reports findings on the rest.
What changed
The Gemma 4 batched-prefill PLE pre-pass (
BuildPerLayerProjectionsBatched) dequanted each token'sper_layer_token_embdrow on the CPU (Parallel.Forover N×10752 elements) and uploaded the full f32 rows. This moves the dequant to the GPU:_bpPleRowAllvia new contiguous-row kernels.The PLE table stays CPU-resident (TierPlanner budget + auto-context untouched, no extra VRAM), so this is a no-regression change on the 12 GB card. New kernels
llm_dequant_rows_q8_0/llm_dequant_rows_q6k; backendDequantRowsQ8_0/Q6K+UploadRawInto. Kill-switchSHARPI_PLE_GPU_DEQUANT=0(also the F32-table fallback).Correctness
Dequantize.ToFloat32(both(d·scale)·q, exactcvt.f32.f16), so the strictGemma4_E4B_BatchedPrefill_GemmOff_MatchesSequentialBitExactoracle still holds.CudaDequantRowsTestsproves bit-identity directly for Q8_0 + Q6_K incl. the real 10752-wide row (6/6 pass).Gemma4CudaBatchedPrefillTestssuite passes 8/8 (incl. the bit-exact oracle, flash/GEMM/MMQ argmax oracles, past-window, chunked >4096).TreatWarningsAsErrors).Measured (E4B-Q8_0,
-g -1, warm)PLE phase 80 → 63 ms @ N=4096 (~21%), 8 → 4 ms @ N=972 — at zero VRAM cost, plus reduced host RAM and PCIe traffic. The
plephase is dominated by the unavoidable per-layer-proj GEMM, so the dequant was the removable slice; the original "~30% of batched prefill" (#141) doesn't reproduce on the current warm, layer-optimized path.Notes on the rest of #247
Investigated and reported in the issue:
FlashAttentionPrefillTc2._M-quant trunks on multi-chunk/repeated prefills, none available to measure.A 12B note for the issue: the 12B has no PLE table (
embedding_length_per_layer_input=0) — PLE is E2B/E4B only.🤖 Generated with Claude Code