perf(gemma4): GPU-side PLE pre-pass dequant for batched prefill (#247 item 1)#256

Merged: pekkah merged 3 commits into master from claude/gemma4-prefill-optimization-h7syay on Jun 15, 2026

pekkah commented Jun 15, 2026 (Owner)

Implements item #1 of the #247 Gemma 4 prefill attack plan — the PLE pre-pass dequant — and reports findings on the rest.

What changed

The Gemma 4 batched-prefill PLE pre-pass (BuildPerLayerProjectionsBatched) dequanted each token's per_layer_token_embd row on the CPU (Parallel.For over N×10752 elements) and uploaded the full f32 rows. This moves the dequant to the GPU:

  • Gather the chunk's raw packed rows on the CPU (a memcpy, no FP math).
  • Upload the 4×-smaller quant bytes (vs f32 rows).
  • Dequant on-device into _bpPleRowAll via new contiguous-row kernels.

The PLE table stays CPU-resident (TierPlanner budget + auto-context untouched, no extra VRAM), so this is a no-regression change on the 12 GB card. New kernels llm_dequant_rows_q8_0 / llm_dequant_rows_q6k; backend DequantRowsQ8_0/Q6K + UploadRawInto. Kill-switch SHARPI_PLE_GPU_DEQUANT=0 (also the F32-table fallback).
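
A rough sketch of the size arithmetic and the per-block math behind the steps above, assuming the standard GGML Q8_0 layout (one f16 scale followed by 32 int8 quants, 34 bytes per block). `Q8_0Sketch` and its members are illustrative names, not the kernels or backend methods added in this PR, and the loop is a CPU stand-in for what the device kernel would compute per element.

```csharp
using System;
using System.Buffers.Binary;

// Demo: one Q8_0 block with scale d = 0.5 and quants -16..15.
byte[] packed = new byte[Q8_0Sketch.BlockBytes];
BinaryPrimitives.WriteUInt16LittleEndian(packed, BitConverter.HalfToUInt16Bits((Half)0.5f));
for (int j = 0; j < Q8_0Sketch.BlockElems; j++)
    packed[2 + j] = unchecked((byte)(j - 16));   // two's-complement int8

float[] row = new float[Q8_0Sketch.BlockElems];
Q8_0Sketch.DequantRow(packed, row);
Console.WriteLine(row[0] == -8f && row[31] == 7.5f);   // True
Console.WriteLine(Q8_0Sketch.BytesPerRow(10752));       // 11424 (vs 43008 bytes as f32)

// Hypothetical helper; assumes the GGML Q8_0 block layout.
static class Q8_0Sketch
{
    public const int BlockElems = 32;
    public const int BlockBytes = 2 + BlockElems;   // f16 scale + 32 int8 quants

    // Packed size of one row; rowDim is assumed to be a multiple of 32.
    public static int BytesPerRow(int rowDim) => rowDim / BlockElems * BlockBytes;

    // Dequantize one packed row: value = d * q for every quant in each block.
    public static void DequantRow(ReadOnlySpan<byte> packed, Span<float> dst)
    {
        int blocks = dst.Length / BlockElems;
        for (int b = 0; b < blocks; b++)
        {
            ReadOnlySpan<byte> blk = packed.Slice(b * BlockBytes, BlockBytes);
            float d = (float)BitConverter.UInt16BitsToHalf(
                BinaryPrimitives.ReadUInt16LittleEndian(blk));
            for (int j = 0; j < BlockElems; j++)
                dst[b * BlockElems + j] = d * unchecked((sbyte)blk[2 + j]);
        }
    }
}
```

Under that layout a 10752-wide row packs into 11424 bytes against 43008 bytes as f32, roughly a 3.8× reduction, which lines up with the "4×-smaller" upload described above.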

Correctness

  • GPU dequant is bit-identical to CPU Dequantize.ToFloat32 (both (d·scale)·q, exact cvt.f32.f16), so the strict Gemma4_E4B_BatchedPrefill_GemmOff_MatchesSequentialBitExact oracle still holds.
  • New CudaDequantRowsTests proves bit-identity directly for Q8_0 + Q6_K incl. the real 10752-wide row (6/6 pass).
  • Full Gemma4CudaBatchedPrefillTests suite passes 8/8 (incl. the bit-exact oracle, flash/GEMM/MMQ argmax oracles, past-window, chunked >4096).
  • Solution builds clean (0 warnings under TreatWarningsAsErrors).
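
For readers wondering what "bit-identical" means in practice, here is a minimal sketch of a bitwise comparison. `BitCompare.FirstMismatch` is a hypothetical helper, not the assertion used in CudaDequantRowsTests. Comparing raw bit patterns rather than using `==` catches differences that float equality hides (`0.0f == -0.0f` is true) and treats identical NaN payloads as equal.

```csharp
using System;

float negZero = BitConverter.Int32BitsToSingle(unchecked((int)0x80000000));
float[] cpu = { 1.5f, 0.0f, float.NaN };
float[] gpu = { 1.5f, negZero, float.NaN };

Console.WriteLine(BitCompare.FirstMismatch(cpu, cpu));   // -1 (NaN matches itself bitwise)
Console.WriteLine(BitCompare.FirstMismatch(cpu, gpu));   // 1  (+0.0 vs -0.0 differ in the sign bit)

// Hypothetical helper for illustration only.
static class BitCompare
{
    // Returns the index of the first element whose bit pattern differs, or -1 if none.
    public static int FirstMismatch(ReadOnlySpan<float> expected, ReadOnlySpan<float> actual)
    {
        if (expected.Length != actual.Length)
            throw new ArgumentException("Length mismatch.");
        for (int i = 0; i < expected.Length; i++)
        {
            if (BitConverter.SingleToInt32Bits(expected[i]) !=
                BitConverter.SingleToInt32Bits(actual[i]))
                return i;
        }
        return -1;
    }
}
```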

Measured (E4B-Q8_0, -g -1, warm)

PLE phase 80 → 63 ms @ N=4096 (~21% faster), 8 → 4 ms @ N=972, at zero VRAM cost, plus reduced host RAM and PCIe traffic. The PLE phase is dominated by the unavoidable per-layer-proj GEMM, so the dequant was the removable slice; the original "~30% of batched prefill" figure (#141) doesn't reproduce on the current warm, layer-optimized path.

Notes on the rest of #247

Investigated and reported in the issue:

A 12B note for the issue: the 12B has no PLE table (embedding_length_per_layer_input=0) — PLE is E2B/E4B only.

🤖 Generated with Claude Code

…item 1)

The Gemma 4 batched-prefill PLE pre-pass (BuildPerLayerProjectionsBatched)
dequanted each token's per_layer_token_embd row on the CPU (Parallel.For over
N x stackedDim elements) and then uploaded the full f32 rows -- ~30% of batched
prefill per the original #141 profiling.

Move the dequant to the GPU: gather the chunk's raw packed rows on the CPU (a
memcpy, no FP math), upload the 4x-smaller quant bytes, then dequant on-device
into _bpPleRowAll via new contiguous-row kernels. The PLE table stays
CPU-resident (TierPlanner budget and auto-context untouched, no extra VRAM), so
this is a no-regression win on the 12 GB card. The 12B has no PLE table
(embedding_length_per_layer_input=0); PLE is E2B/E4B only (Q8_0 or Q6_K).

- New kernels llm_dequant_rows_q8_0 / llm_dequant_rows_q6k (grid.x=n_rows,
  block=256, row_dim % 256); backend DequantRowsQ8_0/Q6K + UploadRawInto.
- Bit-identical to CPU Dequantize.ToFloat32 (both (d*scale)*q, exact cvt.f32.f16)
  so the GemmOff bit-exact prefill oracle still holds. New CudaDequantRowsTests
  proves it directly for Q8_0 + Q6_K incl. the real 10752-wide row.
- Kill-switch SHARPI_PLE_GPU_DEQUANT=0 forces the CPU fallback (also the path for
  an F32 table or a quant without a dequant-rows kernel).

Measured (E4B-Q8_0, -g -1, warm): PLE phase 80->63 ms @n=4096 (~21%),
8->4 ms @n=972; argmax-stable and bit-exact vs the per-token loop.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements GPU-side PLE (Per-Layer Embedding) pre-pass dequantization for Gemma-4 (issue #247) by introducing CUDA kernels for Q8_0 and Q6_K packed rows. This optimization offloads the expensive dequantization math from the CPU to the GPU and reduces host-to-device upload sizes by 4x. Feedback on the changes suggests adding an early return for empty source spans in UploadRawInto to prevent potential null pointer issues, optimizing the hot Parallel.For loop in BuildPerLayerProjectionsBatched by pinning the array and using Buffer.MemoryCopy instead of slicing spans, and using BitConverter.HalfToUInt16Bits in unit tests to avoid heap allocations.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +838 to +846
    public void UploadRawInto(Tensor dst, ReadOnlySpan<byte> src)
    {
        if (!_devPtrs.TryGetValue(dst.Handle, out var entry))
            throw new InvalidOperationException($"UploadRawInto: handle {dst.Handle} not registered.");
        if ((nuint)src.Length > entry.byteSize)
            throw new ArgumentException($"UploadRawInto: source ({src.Length} bytes) exceeds destination capacity ({entry.byteSize} bytes).");
        fixed (byte* s = src)
            UploadViaStaging(entry.devPtr, s, (nuint)src.Length);
    }

medium

If src is empty, calling fixed (byte* s = src) will result in s being null. Depending on how UploadViaStaging handles null pointers and zero lengths, this could lead to undefined behavior or unnecessary overhead. Adding an early return when src is empty is a safer and more defensive approach.

    public void UploadRawInto(Tensor dst, ReadOnlySpan<byte> src)
    {
        if (src.IsEmpty)
            return;
        if (!_devPtrs.TryGetValue(dst.Handle, out var entry))
            throw new InvalidOperationException($"UploadRawInto: handle {dst.Handle} not registered.");
        if ((nuint)src.Length > entry.byteSize)
            throw new ArgumentException($"UploadRawInto: source ({src.Length} bytes) exceeds destination capacity ({entry.byteSize} bytes).");
        fixed (byte* s = src)
            UploadViaStaging(entry.devPtr, s, (nuint)src.Length);
    }
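
The null-pointer behavior the comment refers to can be checked in isolation. This is a small sketch with a hypothetical `PinDemo` helper, assuming the project is compiled with unsafe code enabled (`AllowUnsafeBlocks`); it does not call `UploadViaStaging` or any other method from this PR.

```csharp
using System;

Console.WriteLine(PinDemo.PinsToNull(ReadOnlySpan<byte>.Empty));   // True
Console.WriteLine(PinDemo.PinsToNull(new byte[] { 1, 2, 3 }));     // False

// Hypothetical helper for illustration only.
static class PinDemo
{
    // Pins the span and reports whether the resulting pointer is null.
    // For an empty span, GetPinnableReference yields a null reference,
    // so the pinned pointer is null as well.
    public static unsafe bool PinsToNull(ReadOnlySpan<byte> src)
    {
        fixed (byte* p = src)
            return p == null;
    }
}
```

Whether a null pointer with a zero length is actually harmful depends on the callee, which is why an early return on an empty span is the defensive choice.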

Comment on lines +2781 to +2788
            int bpr = _pleBytesPerRow;
            var qhost = _bpPleQuantHost!;
            System.Threading.Tasks.Parallel.For(0, N, i =>
            {
                byte* rowPtr = basePtr + (long)tokens[i] * bpr;
                new ReadOnlySpan<byte>(rowPtr, bpr)
                    .CopyTo(new Span<byte>(qhost).Slice(i * bpr, bpr));
            });

medium

Creating Span and ReadOnlySpan instances and slicing them inside a hot Parallel.For loop introduces unnecessary overhead. Since Parallel.For is a blocking call, we can pin the qhost array once outside the loop using a fixed statement and perform a direct memory copy using Buffer.MemoryCopy. This avoids span allocation/slicing overhead and is significantly faster.

            int bpr = _pleBytesPerRow;
            var qhost = _bpPleQuantHost!;
            fixed (byte* qhostPtr = qhost)
            {
                System.Threading.Tasks.Parallel.For(0, N, i =>
                {
                    byte* rowPtr = basePtr + (long)tokens[i] * bpr;
                    Buffer.MemoryCopy(rowPtr, qhostPtr + (long)i * bpr, bpr, bpr);
                });
            }

Comment on lines +30 to +31
    private static ushort HalfToUshort(Half h) =>
        BitConverter.ToUInt16(BitConverter.GetBytes(h), 0);

medium

Using BitConverter.GetBytes(h) on a Half is highly inefficient because it allocates a new byte[] array on the heap on every call. Since this method is called thousands of times in a loop during testing, it creates significant GC pressure. Using BitConverter.HalfToUInt16Bits(h) completely avoids allocations and is extremely fast.

    private static ushort HalfToUshort(Half h) =>
        BitConverter.HalfToUInt16Bits(h);
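
A quick standalone check that the two forms agree, offered as a sketch rather than as part of the test suite. It only uses `BitConverter` calls available since .NET 6.

```csharp
using System;

Half h = (Half)1.0f;

// Round-trips through a freshly allocated 2-byte array.
ushort viaBytes = BitConverter.ToUInt16(BitConverter.GetBytes(h), 0);

// Reinterprets the 16 bits directly, with no array allocation.
ushort viaBits = BitConverter.HalfToUInt16Bits(h);

Console.WriteLine(viaBytes == viaBits);   // True
Console.WriteLine(viaBits == 0x3C00);     // True (IEEE 754 binary16 encoding of 1.0)
```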

pekkah commented Jun 15, 2026 (Owner, Author)

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements GPU-side dequantization for Gemma-4 PLE pre-pass rows (supporting Q8_0 and Q6_K formats), moving the work off the CPU and reducing host-to-device upload sizes by 4x. It also adds comprehensive unit tests to ensure bit-identical results with the CPU fallback. The feedback suggests adding a defensive guard for empty token lists to prevent an ArgumentOutOfRangeException, and hoisting Span creations out of the Parallel.For loops in both the GPU and CPU fallback paths to optimize performance.

Comment on lines 2768 to 2771
    private void BuildPerLayerProjectionsBatched(IReadOnlyList<int> tokens)
    {
        int N = tokens.Count, L = _hp.NumLayers;
        int stackedDim = L * _pleWidth;

medium

Consider adding a defensive guard to handle the case where tokens is empty (N == 0). If N is 0, executing the GPU dequantization path will call DequantRowsQ8_0 or DequantRowsQ6K with nRows = 0, which throws an ArgumentOutOfRangeException.

    private void BuildPerLayerProjectionsBatched(IReadOnlyList<int> tokens)
    {
        int N = tokens.Count;
        if (N == 0) return;
        int L = _hp.NumLayers;
        int stackedDim = L * _pleWidth;

Comment on lines +2781 to +2788
            int bpr = _pleBytesPerRow;
            var qhost = _bpPleQuantHost!;
            System.Threading.Tasks.Parallel.For(0, N, i =>
            {
                byte* rowPtr = basePtr + (long)tokens[i] * bpr;
                new ReadOnlySpan<byte>(rowPtr, bpr)
                    .CopyTo(new Span<byte>(qhost).Slice(i * bpr, bpr));
            });

medium

To avoid repeatedly creating a new Span<byte> wrapper around qhost on every iteration of the parallel loop, you can hoist the span creation outside of the Parallel.For loop.

            int bpr = _pleBytesPerRow;
            var qhost = _bpPleQuantHost!;
            Span<byte> qhostSpan = qhost;
            System.Threading.Tasks.Parallel.For(0, N, i =>
            {
                byte* rowPtr = basePtr + (long)tokens[i] * bpr;
                new ReadOnlySpan<byte>(rowPtr, bpr)
                    .CopyTo(qhostSpan.Slice(i * bpr, bpr));
            });

Comment on lines +2799 to +2810
            int bytesPerRow = _pleBytesPerRow;
            var host = _bpPleRowHostAll!;
            var dtype = _pleDType;
            System.Threading.Tasks.Parallel.For(0, N, i =>
            {
                byte* rowPtr = basePtr + (long)tokens[i] * bytesPerRow;
                var dst = new Span<float>(host).Slice(i * stackedDim, stackedDim);
                if (dtype == DType.Float32)
                    new ReadOnlySpan<float>((float*)rowPtr, stackedDim).CopyTo(dst);
                else
                    Dequantize.ToFloat32(new ReadOnlySpan<byte>(rowPtr, bytesPerRow), dst, dtype, stackedDim);
            });

medium

Similarly, in the CPU fallback path, you can hoist the creation of the Span<float> wrapper around host outside of the Parallel.For loop to avoid redundant allocations on each iteration.

            int bytesPerRow = _pleBytesPerRow;
            var host = _bpPleRowHostAll!;
            Span<float> hostSpan = host;
            var dtype = _pleDType;
            System.Threading.Tasks.Parallel.For(0, N, i =>
            {
                byte* rowPtr = basePtr + (long)tokens[i] * bytesPerRow;
                var dst = hostSpan.Slice(i * stackedDim, stackedDim);
                if (dtype == DType.Float32)
                    new ReadOnlySpan<float>((float*)rowPtr, stackedDim).CopyTo(dst);
                else
                    Dequantize.ToFloat32(new ReadOnlySpan<byte>(rowPtr, bytesPerRow), dst, dtype, stackedDim);
            });

pekkah and others added 2 commits June 15, 2026 13:59
…her bound

Review-cycle follow-ups to #256:
- BuildQ8_0Rows now spans the full int8 range [-128,127] (was [-127,127]),
  so the bit-identity assertion exercises the -128 quant too.
- Document that N*bpr stays within int at the 4096 chunk cap.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…in test

Address Gemini review on #256:
- UploadRawInto returns early on an empty span (after validation) to avoid
  fixed → null-ptr.
- Test HalfToUshort uses BitConverter.HalfToUInt16Bits (no per-call byte[]).

Not applied: the Parallel.For span/pointer hoisting suggestions don't compile
(Span<byte> is a ref struct, can't be captured in the lambda; new Span(array)
is already allocation-free), and the N==0 guard targets an unreachable case
(empty prompts are never prefilled).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

pekkah commented Jun 15, 2026 (Owner, Author)

Thanks for the review. Disposition of the findings:

Applied (commit 8af6870):

  • UploadRawInto — early-return on an empty span (kept after the handle/capacity validation) to avoid fixed → null-ptr.
  • Test HalfToUshort — switched to BitConverter.HalfToUInt16Bits (no per-call byte[]).

Also covered the full int8 range incl. -128 in the test and documented the int-safe N*bpr bound at the 4096 chunk cap (commit edd5cae).

Not applied (with rationale):

  • Hoisting the Span/pinned-pointer out of the Parallel.For — these don't compile: Span<byte> is a ref struct and can't be captured in a lambda (CS8175), and a fixed-pinned pointer can't either (CS1686). The current new Span<byte>(array).Slice(...) inside the loop is allocation-free (it's a stack struct over the existing array ref), so there's no heap/GC overhead to remove.
  • N == 0 guard in BuildPerLayerProjectionsBatched — unreachable: a prompt is never empty, and the surrounding PrefillBatchedTrunk (embed all N, final norm on token N-1) doesn't support N=0 either, so a local guard would be false comfort rather than making N=0 work.
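
The pattern kept in the PR (capture the array, build the span inside the lambda) can be sketched with managed arrays only. `GatherSketch.GatherRows` is a hypothetical stand-in for the gather step, not the code in BuildPerLayerProjectionsBatched: the real path reads from a raw pointer with `long` offsets, while this sketch assumes a table small enough for `int` indexing.

```csharp
using System;
using System.Threading.Tasks;

// Demo: a 3-row table with 4 bytes per row, gathered for tokens 2, 0, 2.
byte[] table = new byte[12];
for (int i = 0; i < table.Length; i++) table[i] = (byte)i;
byte[] gathered = new byte[12];

GatherSketch.GatherRows(table, new[] { 2, 0, 2 }, 4, gathered);
Console.WriteLine(string.Join(",", gathered));   // 8,9,10,11,0,1,2,3,8,9,10,11

// Hypothetical helper for illustration only.
static class GatherSketch
{
    // Copies row tokens[i] of the packed table into slot i of dst.
    // The lambda captures only arrays and ints; each span is created inside
    // the lambda body, where it is a stack-only view with no heap allocation.
    public static void GatherRows(byte[] table, int[] tokens, int bpr, byte[] dst)
    {
        Parallel.For(0, tokens.Length, i =>
        {
            table.AsSpan(tokens[i] * bpr, bpr).CopyTo(dst.AsSpan(i * bpr, bpr));
        });
    }
}
```

Declaring a `Span<byte>` local outside the lambda and using it inside would be rejected by the compiler, since ref struct locals cannot be captured; constructing the span per iteration sidesteps that. On the destination side, if the Q8_0 row is 11424 bytes (an assumption based on the standard GGML layout), then 4096 rows come to 46,792,704 bytes, comfortably within `int` range.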

pekkah merged commit 30fca96 into master on Jun 15, 2026
1 check passed
pekkah deleted the claude/gemma4-prefill-optimization-h7syay branch on June 15, 2026 11:03