Skip to content

perf(gemma4-cuda): MMQ + flash-attention prefill + dp4a decode (#141/#142)#143

Merged
pekkah merged 12 commits into
masterfrom
perf/gemma4-cuda-prefill-gemm-decode-dp4a-141-142
Jun 6, 2026
Merged

perf(gemma4-cuda): MMQ + flash-attention prefill + dp4a decode (#141/#142)#143
pekkah merged 12 commits into
masterfrom
perf/gemma4-cuda-prefill-gemm-decode-dp4a-141-142

Conversation

@pekkah

@pekkah pekkah commented Jun 5, 2026

Copy link
Copy Markdown
Owner

Summary

Closes the Gemma 4 E4B CUDA performance gap to llama.cpp on prefill (issues #141 / #142), on an RTX 4070 Ti with gemma-4-E4B-it-Q8_0.

before this branch after llama.cpp (same box)
Prefill (~1K ctx) 109 t/s 2853 (26×) ~8000
Prefill (~1.8K ctx) ~2920
Decode @d0 ~49 59 ~78

The two foundational commits (#141 cuBLAS-GEMM prefill + #142 dp4a/graphs decode) landed first (109→1564 prefill, 49→59 decode). The rest of this branch then closed most of the remaining prefill gap after profiling revealed the real bottleneck.

The key finding

The issue framed prefill as matmul-bound, but profiling at a realistic prompt length (N=1848) showed attention was 67% of prefill — it's O(n²) and the old "~30%" note was a short-prompt artifact — and memory-bound on redundant K/V reads (the scalar per-query kernel re-streamed each query's K/V window up to ~512× for SWA layers; ~79× off compute peak). So the lever was attention, not the matmul.

What's in the branch

Prefill (#141):

  • int8 tensor-core MMQ matmul llm_mmq_q8_0 — inline-PTX mma.sync.m16n8k32.s8, Q8_0 weights × Q8_1 activations, register-prefetch double-buffered. Reads each weight once as int8 with no fp16 HBM dequant temp; beats cuBLAS GEMM (matmul 332→316ms). Default-on (SHARPI_PREFILL_MMQ=0 reverts).
  • flash-attention prefill llm_flash_attn_prefill_f32 — shared K/V tiles reused across a 16-query tile + online softmax (FA2 style), half2 fp16x2 QK dot via inline PTX (K staged fp16, V fp32 for exact PV). Attention 929→245ms (3.8×). Default-on (SHARPI_PREFILL_FLASH=0 reverts).
  • batched Q8_0 embedding llm_embed_lookup_q8_0_batched — collapses 2·N host launches into one (embed 62→17ms).
  • cuBLAS-GEMM prefill foundation + F32 PLE via TF32 Sgemm.

Decode (#142): dp4a/Q8_1 int8 matvec + CUDA graphs default-on.

Net: Gemma 4 prefill ~1389 → 2920 t/s (2.05×) this session; 109 → 2853 across the whole branch. Decode untouched this session.

Correctness

MMQ and flash are argmax-stable, not bit-exact (int8 / fp16 rounding) — by design, consistent with the existing cuBLAS-GEMM prefill posture. The bit-exact parity oracles deliberately pin them off; new tolerance oracles cover the fast paths:

  • CudaMmqQ8_0Tests (int8 dot vs fp32 DotQ8_0)
  • CudaFlashAttnTests (vs both scalar attention kernels: GQA, head_dim 256/512, windowing, partial tiles)
  • Gemma4_E4B_BatchedPrefill_{Mmq,FlashAttn}MatchesSequential (e2e argmax-stability)

22 focused CUDA tests green; the full CUDA suite (120) was green before the final additive (batched-embed) commit. Reviewed with pr-review-toolkit (code-reviewer + silent-failure-hunter): no critical/important findings.

Remaining gap

Prefill is now well-balanced (matmul 45% / attention 36% / embed+ple 19%); closing the rest needs a full tensor-core flash at d=512 (occupancy-limited even in llama.cpp) and decode-side work (dp4a qs-alignment).

🤖 Generated with Claude Code

pekkah and others added 12 commits June 5, 2026 19:50
…ec (#142)

Close the bulk of the Gemma 4 GPU gap vs llama.cpp on the RTX 4070 Ti
(gemma-4-E4B-it-Q8_0):

  prefill ~1K ctx: 111 -> 1575 t/s (14.2x; llama.cpp ~8000)
  decode  ~1K ctx:  43 -> 50.6 t/s (dp4a + CUDA graphs; llama.cpp 77.7)
  decode  low ctx:  27 -> 58   t/s

#141 prefill — route the batched-trunk matmuls through a compute-bound
cuBLAS GEMM instead of the memory-bound matvec GEMM-N (which re-streamed
each weight per token):
  - llm_dequant_q8_0_to_f16 + llm_f32_to_f16 kernels feed cublasGemmEx
    (fp16xfp16 -> fp32, fp32 accum); each Q8_0 weight is dequantized once
    per batch. CudaBackend.MatMulBatchedGemm + fp16 scratch (grow-only).
  - F32 PLE projections go through Sgemm's TF32 path (was matvec GEMM-N,
    re-streaming a 110 MB weight per token -> 106ms; now ~6ms).
  - PLE host dequant parallelized across tokens.
  - PrefillGemmEnabled (SHARPI_PREFILL_GEMM) gates it; argmax-stable vs the
    fp32 sequential reference, not bit-exact (fp16 rounding).

#142 decode — profiling showed decode is matvec-bound (FFN+QKV+O = 71%),
not attention-bound (6.4%), so the win is a tighter Q8_0 matvec, not flash
attention:
  - llm_matvec_q8_0_dp4a: quantize the activation to Q8_1 and use __dp4a
    int8 dot products (llama.cpp's MMVQ approach), replacing the per-element
    int8->float decode. Q80Dp4aEnabled (SHARPI_Q80_DP4A), default on.
  - EnsureQ81Scratch pre-grows the Q8_1 buffer before the first CUDA-graph
    capture (capture forbids cudaMalloc).

Tests: Gemma4 CUDA suite green (batched-prefill oracle moved to a tolerance
model + a dp4a-off bit-exact case; CudaQ8_0/CudaMatMulBatched bit-exact
oracles pin Q80Dp4aEnabled=false and a new dp4a tolerance test added).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…de (#142)

After #137's launch fusions and the dp4a decode matvec, CUDA-graph
capture/replay measures +9-10% on the all-GPU Gemma 4 decode at BOTH low
and ~1K context (49.9->was 46, 58.2->was 53 t/s) — the short-context
regression that kept #136 default-off is gone. Flip _useCudaGraph to default
on (SHARPI_CUDA_GRAPH=0 reverts). Scoped to CudaForwardPass; the hybrid path
stays env-gated. Bit-parity already proven by Gemma4CudaGraphParityTests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…, decode 49→59)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…-k 64/top-p 0.95, --no-thinking)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… model)

Gemma 4 E4B-it's chat template renders a <|channel>thought block but the
model wasn't trained to fill it, so the default enable_thinking=true made it
degenerate into tag-soup. Detect arch == "gemma4" and default thinking off
(prints a one-line notice); --no-thinking still forces off for every other
model and the effective state flows to the temp-0 warning, prompt formatting,
and the MTP gate via s_noThinking. Non-Gemma models are unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Adds llm_mmq_q8_0: a shared-tiled m16n8k32 s8 mma matmul that multiplies
Q8_0 weights by Q8_1-quantized activations directly on int8 tensor cores,
reading each weight once as int8 with no fp16 HBM dequant temp (the cost
that capped MatMulBatchedGemm). 64x64 output tile, 8 warps x 4 mma/K-block,
per-block fp32 scale accumulation, no sum/bias term (Q8_0 symmetric, D4).

Wired into the Gemma 4 batched-trunk prefill behind PrefillMmqEnabled /
SHARPI_PREFILL_MMQ (default off during bring-up; takes precedence over the
GEMM path when on). New CudaBackend.MatMulBatchedMmq + sharpi_uint_at
misaligned-word helper. Parity validated in CudaMmqQ8_0Tests (int8 dot is
exact; tracks the fp32 DotQ8_0 reference to the dp4a per-row-RMS tolerance).

Profiling note: at realistic prompt lengths (N=1848) the prefill matmul is
only ~24% of the time; the scalar O(n^2) attention is ~67% (909ms). The MMQ
foundation lands here; matmul-kernel tuning + flash-attention prefill follow.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Register-prefetch double-buffering for llm_mmq_q8_0: the next K-tile's global
loads are issued into registers while the current tile's mma's run, so global
latency hides behind compute instead of stalling the per-K-block barrier
(cp.async is unusable — the Q8_0 qs is only 2-byte aligned). Widened the tile
to 64x128 (halves the nTok/BN weight re-read factor). Over a 1848-tok Gemma 4
prefill the matmul portion drops ~332ms (cuBLAS GEMM) -> ~316ms (MMQ), so MMQ
now beats the dequant->fp16->cuBLAS path and drops its fp16 weight HBM temp.

PrefillMmqEnabled now defaults on (SHARPI_PREFILL_MMQ=0 reverts), gated under
PrefillGemmEnabled. The GEMM-path oracle pins MMQ off to stay isolated; a new
Gemma4_E4B_BatchedPrefill_MmqMatchesSequential asserts MMQ argmax-stability.

End-to-end prefill is ~unchanged (1389 t/s): the 16ms matmul win is ~1% of the
attention-dominated total. Profiling at N=1848 shows the scalar O(n^2)
llm_full_seq_attention is ~67% (909ms) of prefill — the next target.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replaces the scalar llm_full_seq_attention / llm_attention_swa_batched (one
256-thread block PER query, each re-reading its whole K/V range from global —
O(n^2); for SWA, adjacent queries' 512-wide windows overlap ~99% so K/V is
re-streamed up to ~512x) with llm_flash_attn_prefill_f32: a block handles a
tile of 8 queries of one head and streams K/V through shared-memory tiles with
an online softmax (FlashAttention-2 style — running max/sum + rescaled output
accumulator, no n^2 score buffer). Each key is read from global once per 8
queries instead of once per query. One warp per query; lane L owns head dims
{L,L+32,..}; GQA, causal + optional sliding window, per-layer head_dim (256
SWA / 512 global). fp32 KV. Argmax-stable (online softmax reassociates the
same sum), not bit-exact.

Profiling at N=1848 showed attention was the dominant prefill cost and
memory-bound on redundant K/V reads (~79x off compute peak), not compute-bound.
Result: attention ~929ms -> ~411ms (2.26x), Gemma 4 prefill ~1389 -> ~2180 t/s
(1.57x). Decode unchanged (flash is prefill-only).

CudaBackend.FlashAttentionPrefill (dynamic shared sized to a 48KB budget),
PrefillFlashAttnEnabled / SHARPI_PREFILL_FLASH (default on). Parity validated
in CudaFlashAttnTests (vs both scalar kernels: GQA, hd 256/512, windowing,
partial tiles) + e2e Gemma4_E4B_BatchedPrefill_FlashAttnMatchesSequential. The
matvec/GEMM/MMQ bit-exact oracles pin flash off (it isn't bit-exact). 19
focused tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
FA_QT 8 -> 16 (16 warps / 512 threads per block): doubles the K/V reuse factor
(each key read from global once per 16 queries). Attention ~411ms -> ~382ms at
N=1848; Gemma 4 prefill ~2180 -> ~2364 t/s. Diminishing return confirms the
residual attention cost is now the scalar warp-reduce QK/PV compute, not K/V
traffic — tensor-core QK/PV is the next lever. Parity unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The residual flash-attention cost was the scalar fp32 QK warp-reduce. Stage K as
fp16 (half2-packed) in shared and run the QK dot with fma.rn.f16x2 (2 multiply-
accumulates/instruction, ~2x the fp32 FMA rate); V stays fp32 for the exact scalar
PV. Switched the per-lane dim ownership to adjacent PAIRS (pi = lane+32*p, dims
2*pi/2*pi+1) so the shared half2 K loads stay coalesced. New sharpi_f32x2_to_f16x2
/ sharpi_hfma2 / sharpi_f16x2_sum inline-PTX helpers (NVRTC has no cuda_fp16.h).
fp16 K also halves the K shared footprint → larger key tiles, fewer barriers.

K rounded to fp16 (scores tolerate it — argmax-stable, maxAbs ~1.4e-4 vs the fp32
scalar reference); V/PV/softmax stay fp32. CudaFlashAttnTests tolerance relaxed to
5e-3*rms for the fp16-K relative error.

Attention ~382ms -> ~245ms (1.56x); Gemma 4 prefill ~2364 -> ~2847 t/s. Session
total: prefill 1389 -> 2847 t/s (2.05x). 18 focused tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The batched-trunk prefill embedded tokens with a per-token EmbedLookupQ8_0 + a
device copy — 2*N host-driven launches (~62ms at N=1848, launch-overhead bound now
that attention is fast). Collapse it to a single grid.x=N launch: new
llm_embed_lookup_q8_0_batched reads token_ids[blockIdx.x] from a device buffer and
writes row blockIdx.x of _bpHidden directly (per-token body identical → bit-exact).
CudaBackend.EmbedLookupQ8_0Batched; the prefill uploads the ids once and uses it for
the Q8_0 embed table (Gemma 4), keeping the per-token loop for other dtypes.

Embed ~62ms -> ~17ms; Gemma 4 prefill ~2847 -> ~2920 t/s. Session total: prefill
1389 -> 2920 t/s (2.10x). 22 focused CUDA tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…hed embed)

Updates the Gemma 4 E4B CUDA row for the #141 prefill work (int8 tensor-core MMQ,
flash-attention prefill, batched embed; ~1.8× at ~1K ctx, ~2.05× at 1.8K). Also
fixes a stale kernel name (llm_mmq_q8_0_naive → llm_mmq_q8_0) in CudaMmqQ8_0Tests
doc (pr-review-toolkit finding).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant performance optimizations for Gemma 4 prefill and decode paths, including a memory-efficient flash-attention prefill kernel, an int8 tensor-core MMQ matmul, a dp4a-based Q8_0 decode matvec, and batched embedding lookups. The code review feedback suggests several high-impact optimizations: replacing runtime integer division and modulo operations with bitwise shifts and masks in the flash-attention kernel, consolidating dual 8-bit loads into single 16-bit loads for scale factors across multiple CUDA kernels, fixing a late overflow check in MatMulBatchedMmq, and utilizing CollectionsMarshal.AsSpan to avoid allocations when processing token lists.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +3137 to +3145
for (int idx = tid; idx < kt_tile * hd2; idx += (int)blockDim.x) {
int kk = idx / hd2, pr = idx - kk * hd2;
unsigned int kh = 0u;
if (kk < tile_keys) {
long off = (long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + 2 * pr;
kh = sharpi_f32x2_to_f16x2(k_cache[off], k_cache[off + 1]);
}
sKh[idx] = kh;
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since head_dim (and thus hd2) is always a power of two in LLMs (e.g., 256, 512), we can completely avoid the expensive runtime integer division and modulo operations by using bitwise shifts and masks. This significantly improves the performance of the key loading phase in the flash attention prefill kernel.

        int shift_k = __ffs(hd2) - 1;
        int mask_k = hd2 - 1;
        for (int idx = tid; idx < kt_tile * hd2; idx += (int)blockDim.x) {
            int kk = idx >> shift_k;
            int pr = idx & mask_k;
            unsigned int kh = 0u;
            if (kk < tile_keys) {
                long off = (long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + 2 * pr;
                kh = sharpi_f32x2_to_f16x2(k_cache[off], k_cache[off + 1]);
            }
            sKh[idx] = kh;
        }

Comment on lines +3147 to +3152
for (int idx = tid; idx < kt_tile * head_dim; idx += (int)blockDim.x) {
int kk = idx / head_dim, d = idx - kk * head_dim;
sV[idx] = (kk < tile_keys)
? v_cache[(long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + d]
: 0.f;
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similarly to the key loading phase, head_dim is always a power of two, so we can replace the runtime integer division and modulo with bitwise operations to optimize the value loading phase.

        int shift_v = __ffs(head_dim) - 1;
        int mask_v = head_dim - 1;
        for (int idx = tid; idx < kt_tile * head_dim; idx += (int)blockDim.x) {
            int kk = idx >> shift_v;
            int d = idx & mask_v;
            sV[idx] = (kk < tile_keys)
                ? v_cache[(long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + d]
                : 0.f;
        }

Comment on lines +1342 to +1344
unsigned int dlo = sharpi_byte_at(weights, b0);
unsigned int dhi = sharpi_byte_at(weights, b0 + 1);
float dw = sharpi_fp16_to_fp32(dlo | (dhi << 8));

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Loading the 16-bit fp16 scale d using two separate sharpi_byte_at calls is inefficient. Since b0 is always 2-byte aligned, we can perform a single 16-bit load by casting the pointer to const unsigned short*, which compiles to a single ld.global.u16 instruction.

            unsigned int d_bits = *reinterpret_cast<const unsigned short*>(reinterpret_cast<const char*>(weights) + b0);
            float dw = sharpi_fp16_to_fp32(d_bits);

Comment on lines +1455 to +1456
if (tid < MMQ_BM) { long wb = ((long)(row_block + tid) * nb + (KB)) * 34L; \
rWd = (row_block + tid < rows) ? sharpi_fp16_to_fp32(sharpi_byte_at(weights, wb) | (sharpi_byte_at(weights, wb + 1) << 8)) : 0.f; } \

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

We can optimize the scale loading inside the MMQ_LOAD_TILE macro by performing a single 16-bit load instead of two sharpi_byte_at calls.

        if (tid < MMQ_BM) { long wb = ((long)(row_block + tid) * nb + (KB)) * 34L; \
            rWd = (row_block + tid < rows) ? sharpi_fp16_to_fp32(*reinterpret_cast<const unsigned short*>(reinterpret_cast<const char*>(weights) + wb)) : 0.f; } \

Comment on lines +1592 to +1594
unsigned int dlo = sharpi_byte_at(weights, b0 + 0);
unsigned int dhi = sharpi_byte_at(weights, b0 + 1);
float d = sharpi_fp16_to_fp32(dlo | (dhi << 8));

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

We can optimize the scale loading inside llm_dequant_q8_0_to_f16 by performing a single 16-bit load instead of two sharpi_byte_at calls.

        unsigned int d_bits = *reinterpret_cast<const unsigned short*>(reinterpret_cast<const char*>(weights) + b0);
        float d = sharpi_fp16_to_fp32(d_bits);

Comment on lines +1824 to +1827
int qN = (int)((long)cols * nTok);
if ((long)cols * nTok > int.MaxValue)
throw new InvalidOperationException(
$"MatMulBatchedMmq: cols*nTok ({(long)cols * nTok}) exceeds int range.");

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The overflow check (long)cols * nTok > int.MaxValue is performed after casting the value to int on line 1824. If the value exceeds int.MaxValue, qN will already contain an overflowed/truncated value before the check runs. Performing the check on the long value first is safer and cleaner.

            long totalElements = (long)cols * nTok;
            if (totalElements > int.MaxValue)
                throw new InvalidOperationException(
                    $"MatMulBatchedMmq: cols*nTok ({totalElements}) exceeds int range.");
            int qN = (int)totalElements;

Comment on lines +1823 to +1826
int[] ids = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);
var idTensor = _gpu.UploadRaw(
System.Runtime.InteropServices.MemoryMarshal.AsBytes<int>(ids),
TensorShape.D1(N), DType.Float32);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid array allocations when tokens is a List<int>, we can use CollectionsMarshal.AsSpan to get a direct span of the list's underlying array. This completely avoids allocations for both int[] and List<int> collections.

            ReadOnlySpan<int> idsSpan = tokens is int[] arr ? arr :
                                        tokens is System.Collections.Generic.List<int> list ? System.Runtime.InteropServices.CollectionsMarshal.AsSpan(list) :
                                        System.Linq.Enumerable.ToArray(tokens);
            var idTensor = _gpu.UploadRaw(
                System.Runtime.InteropServices.MemoryMarshal.AsBytes<int>(idsSpan),
                TensorShape.D1(N), DType.Float32);

@pekkah

pekkah commented Jun 6, 2026

Copy link
Copy Markdown
Owner Author

Reviewed the Gemini code-assist suggestions — all 7 are legitimate medium-priority micro-opts. Deferred to #145 rather than amending a CI-green, tested, review-toolkit-clean PR at merge time (the two shift/mask suggestions also bake a power-of-two head_dim assumption into an otherwise-general kernel — needs a guard first). Remaining perf gap tracked in #141/#142; full tensor-core flash at d=512 in #146.

@pekkah pekkah merged commit 1c70c27 into master Jun 6, 2026
1 check passed
@pekkah pekkah deleted the perf/gemma4-cuda-prefill-gemm-decode-dp4a-141-142 branch June 6, 2026 06:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant