perf(gemma4-cuda): TC flash-attention prefill (#146/#147) + SoA Q8_0 weight layout (#149) — prefill 2853→4240 #148

Merged: pekkah merged 10 commits into master from perf/gemma4-cuda-flash-tc-146 on Jun 6, 2026

pekkah commented Jun 6, 2026

Summary

Two stacked, default-on Gemma 4 CUDA prefill wins, each built bottom-up and validated. Prefill 2853 → 4240 t/s (1.8K ctx) = 34% → 50% of llama.cpp's ~8475. All fast paths argmax-stable; the SoA repack is bit-identical. Full ForwardPass suite 382/382 green.

1. Tensor-core flash-attention prefill (#146/#147)

Both QK^T and P·V now run on the tensor cores (mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32: fp16 multiplicands, fp32 accumulate), replacing the scalar O(n²) per-query attention.

2. SoA Q8_0 weight layout (#149)

Repacks 2-D Q8_0 weights at upload into [quants rows*cols B][scales rows*nb fp16], so every reader loads the quants with 16-byte-aligned loads. The interleaved 34-byte block puts the quants at a 2-byte offset, which costs each weight word an extra global load plus a __funnelshift.

  • One-time GPU repack (llm_q8_0_repack_soa); all five Q8_0 readers (MMQ, dp4a, fp32 matvec, GEMM-N, dequant) get bit-identical SoA variants and auto-route per repacked handle.
  • Bit-identical end-to-end (CudaMmqSoaE2ETests: prefill + 4 decode, maxAbs 0) — the bit-exact oracles exercise the SoA path (converting every reader is what keeps them meaningful; a half-migration would false-pass on garbage).
  • +10-12% prefill by default (N=965 3297→3698, N=2054 3849→4240). SHARPI_MMQ_SOA=0 reverts.

Dead-ends ruled out (documented, reverted)

  • split-K MMQ: occupancy isn't the lever at real prefill batch sizes (probe mismeasured at nTok=1024).
  • cp.async-pipelined MMQ: bit-identical but slower than register-prefetch on Ada (barriers add overhead; .ca beats .cg but both lose).

Diagnostics kept

CudaMmqRooflineProbe (int8 TOPS vs peak at FFN shapes) — but probes must use the real prefill nTok.

Test plan

  • Focused: dotnet test tests/SharpInference.Tests.ForwardPass --filter "Gemma4Cuda|CudaFlashAttn|CudaMmq" (26 green).
  • Full ForwardPass suite: 382/382, 0 failed.

🤖 Generated with Claude Code

pekkah and others added 2 commits June 6, 2026 09:48
…ve (#146)

Single-warp NVRTC test kernel + host wrapper + unit test proving the
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 fragment layouts (A 16x16
row-major, B 16x8 col-major, C 16x8 fp32) against a CPU fp16-rounded reference.
maxAbs 4.77e-7, 0/128 mismatches — the A/B/C lane->register maps are correct.

This de-risks the hardest unknown for the full TC flash-attention prefill: a
wrong fragment map silently produces garbage. Reusable building block for the
QK^T and P.V mma stages.
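
The fragment maps this test pins down can be written out on the host. The sketch below is a plain-Python rendering of the m16n8k16 lane-to-element maps as I read them from the PTX ISA tables; the helper names are illustrative and not taken from the PR. It also checks that two adjacent 16x8 accumulator tiles line up with one 16x16 A fragment lane by lane, which is the property a no-transpose score-to-P reuse depends on.

```python
# Lane -> (row, col) maps for mma.sync m16n8k16 fragments.
# Assumption: layouts follow the PTX ISA tables; names are hypothetical.

def c_frag(lane, i):
    """C/D fragment (16x8 fp32): value i in 0..3 held by this lane."""
    group, tig = lane >> 2, lane & 3
    row = group + (8 if i >= 2 else 0)
    col = 2 * tig + (i & 1)
    return row, col


def a_frag(lane, i):
    """A fragment (16x16 fp16, row-major): half i in 0..7 held by this lane."""
    group, tig = lane >> 2, lane & 3
    row = group + (8 if (i & 2) else 0)
    col = 2 * tig + (i & 1) + (8 if i >= 4 else 0)
    return row, col


def c_tiles_form_a_fragment():
    """True if C tiles for key columns 0-7 and 8-15 map onto one A fragment."""
    for lane in range(32):
        for tile in range(2):          # tile 0: keys 0-7, tile 1: keys 8-15
            for i in range(4):
                row, col = c_frag(lane, i)
                if a_frag(lane, 4 * tile + i) != (row, col + 8 * tile):
                    return False
    return True
```

Each lane already holds exactly the elements it needs for the next mma, so only an fp32-to-fp16 conversion is required between the two stages, with no cross-lane shuffle.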

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Full TC version of the half2 flash kernel: both QK^T and P.V run on the mma
cores (mma.sync m16n8k16, fp16 multiplicands / fp32 accumulate), one warp per
16-query tile, online softmax. Key design points:

- The QK^T score C-fragment is reused directly as the P.V A-fragment with no
  transpose — the key index is QK^T's N-col and P.V's contraction dim, and the
  m16n8k16 C and A fragment layouts coincide on (row, 2*tig).
- O[16 x head_dim] is too large for registers (256 regs/lane at d=512), so it
  lives in shared fp32 and is rescaled in place per key-tile; K and V time-share
  one 16 x head_dim fp16 buffer, fitting the whole kernel in 48 KB shared at d=512.
- Guards the masked-tile online-softmax NaN (a later query's early tiles fall
  entirely outside its sliding window -> mnew=-inf -> exp(-inf+inf)).
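
The masked-tile NaN is easy to reproduce on the CPU. The following is a minimal scalar sketch of online softmax for one query row, not the kernel's code: values are scalars for brevity and the `guard` flag is a hypothetical switch added only to show the failure mode.

```python
import math

NEG_INF = float("-inf")


def online_softmax_dot(score_tiles, value_tiles, guard=True):
    """Streaming softmax(scores) . values over key tiles for one query row."""
    m = NEG_INF   # running max of the scores seen so far
    l = 0.0       # running sum of exp(score - m)
    o = 0.0       # running weighted sum of values
    for scores, values in zip(score_tiles, value_tiles):
        m_new = max(m, max(scores))
        if guard and m_new == NEG_INF:
            # Every key so far is masked: there is no finite max to rescale
            # by, so skip the tile instead of computing exp(-inf - (-inf)).
            continue
        scale = math.exp(m - m_new)      # 0.0 on the first live tile
        p = [math.exp(s - m_new) for s in scores]
        l = l * scale + sum(p)
        o = o * scale + sum(pi * vi for pi, vi in zip(p, values))
        m = m_new
    return o / l if l != 0.0 else 0.0


# First tile lies fully outside the window, second tile is live.
scores = [[NEG_INF, NEG_INF], [0.0, 0.0]]
values = [[1.0, 2.0], [3.0, 5.0]]
guarded = online_softmax_dot(scores, values, guard=True)
unguarded = online_softmax_dot(scores, values, guard=False)
```

With the guard the result is the plain average of the two live values; without it the first tile produces a NaN that survives every later rescale.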

Parity (CudaFlashAttnTcTests): maxAbs ~3-4e-4 vs the scalar batched kernels across
GQA, both Gemma 4 head_dims (256 SWA / 512 global), windowing, partial tiles —
more accurate than the half2 kernel. End-to-end FlashTcMatchesSequential green.

A/B vs the half2 baseline (all-GPU, warm, gemma-4-E4B Q8_0, 4070 Ti):
  N=965:  2585 -> 2732 t/s (+5.7%)
  N=2054: 2716 -> 2849 t/s (+4.9%)
The win is occupancy-limited (single warp/block + 48 KB shared -> ~2 warps/SM);
a multi-warp / d-split version is the path to a larger gain. Opt-in for now
(SHARPI_PREFILL_FLASH_TC=1, gated on head_dim % 16 == 0) until that lands.
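
The register and shared-memory figures quoted for the single-warp kernel follow from simple counting. The sketch below redoes that arithmetic; the 100 KB of shared memory per SM is my assumption for an Ada-class part (compute capability 8.9) and is not stated in the PR, and the function names are illustrative.

```python
LANES = 32  # threads per warp


def single_warp_footprint(head_dim, q_tile=16, k_tile=16):
    """Return (fp32 regs/lane if O lived in registers, shared bytes used)."""
    o_elems = q_tile * head_dim
    regs_if_register_resident = o_elems // LANES
    shared_o = o_elems * 4                # fp32 O tile kept in shared memory
    shared_kv = k_tile * head_dim * 2     # one fp16 buffer time-shared by K and V
    return regs_if_register_resident, shared_o + shared_kv


def blocks_per_sm(shared_bytes, sm_shared_bytes=100 * 1024):
    """Blocks that fit on one SM when shared memory is the only limit."""
    return sm_shared_bytes // shared_bytes


regs, shared = single_warp_footprint(512)
resident_warps = blocks_per_sm(shared) * 1   # one warp per block
```

At d=512 this gives 256 registers per lane for a register-resident O, 48 KB of shared memory per block, and two resident single-warp blocks per SM under the stated assumption.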

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a tensor-core flash-attention prefill implementation (Issue #146) utilizing mma.sync.aligned.m16n8k16 instructions. It adds the CUDA kernels llm_mma_test_m16n8k16_f32 and llm_flash_attn_prefill_tc, exposes them via CudaBackend, integrates the tensor-core prefill option into CudaForwardPass (controlled by the SHARPI_PREFILL_FLASH_TC environment variable), and includes benchmarking scripts and comprehensive unit tests to validate correctness against scalar references. No review comments were provided, so there is no additional feedback to address.


pekkah and others added 2 commits June 6, 2026 10:27
…147)

Fixes the single-warp #146 occupancy limit. A block is now W=4 warps cooperating
on one 16-query tile with the head dim split across them: warp w owns output
columns [w*dW, ...) (dW = head_dim/W), so O[16 x dW] is REGISTER-resident (64
regs/lane at d=512) instead of in shared — no per-key-tile shared-O rescale, and
the freed shared lifts occupancy ~2 -> ~16-20 warps/SM. Each warp computes a
PARTIAL QK^T over its d-slice; the partials sum across warps through a small
shared S buffer ([W x 16 x 16] fp32), after which every warp holds the full
reduced score tile and proceeds like the single-warp kernel (no-transpose
score->P, P.V for its slice). Requires head_dim % 64 == 0 (W*16); the #146
single-warp kernel is the head_dim % 16 fallback.
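
The d-split reduction can be checked with a few lines of host code. This is a scalar sketch of the idea only, assuming W=4 and contiguous slices as described above; it is not the kernel, and integer-valued inputs are used so the comparison is exact.

```python
W = 4  # warps per block, each owning head_dim / W columns


def partial_score(q, k, w, warps=W):
    """Dot product of q and k restricted to warp w's slice of the head dim."""
    d_w = len(q) // warps
    lo = w * d_w
    return sum(q[i] * k[i] for i in range(lo, lo + d_w))


def reduced_score(q, k, warps=W):
    """Sum of the per-warp partials, as the shared S buffer would hold it."""
    return sum(partial_score(q, k, w, warps) for w in range(warps))


def o_regs_per_lane(head_dim, warps=W, q_tile=16, lanes=32):
    """fp32 registers per lane for a register-resident O[16 x dW] slice."""
    return q_tile * (head_dim // warps) // lanes


def s_buffer_bytes(warps=W, q_tile=16, k_tile=16):
    """Size of the [W x 16 x 16] fp32 partial-score buffer."""
    return warps * q_tile * k_tile * 4


q = [float(i % 7 - 3) for i in range(256)]
k = [float(i % 5 - 2) for i in range(256)]
full = sum(a * b for a, b in zip(q, k))
```

The partials sum to the full dot product, the O slice needs 64 registers per lane at d=512, and the S buffer is 4 KB, which is why the shared-memory pressure drops so sharply relative to the single-warp kernel.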

Parity (CudaFlashAttnTcTests.TC2): maxAbs identical to the single-warp kernel and
the scalar reference (~3-4e-4) across GQA / 256-SWA & 512-global / windowing /
partial tiles — the d-split reduction is numerically equivalent.

A/B vs half2 (all-GPU, warm, gemma-4-E4B Q8_0, 4070 Ti):
  N=965:  2582 (half2) -> 2691 (tc1) -> 3269 (tc2)  = +26.6%
  N=2054: 2732 (half2) -> 2851 (tc1) -> 3818 (tc2)  = +39.8%  (win grows with ctx)

Now default on (SHARPI_PREFILL_FLASH_TC=0 reverts to half2; =1 + _TC1=1 forces the
single-warp kernel for A/B). Pinned off in the bit-exact / half2 Gemma4 oracles.
All 23 focused Gemma4Cuda/TC/MMQ/primitive tests green.
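
A compact way to read the resulting selection rules is as a small decision function. The sketch below is a hypothetical restatement, not the dispatch code in CudaForwardPass: the parameter names are mine, and falling back to half2 when head_dim is not a multiple of 16 is an assumption.

```python
def select_prefill_kernel(head_dim, flash_tc=True, force_tc1=False, warps=4):
    """Pick a prefill attention kernel label for a given head_dim."""
    if not flash_tc:
        return "half2"    # corresponds to SHARPI_PREFILL_FLASH_TC=0
    if not force_tc1 and head_dim % (warps * 16) == 0:
        return "tc2"      # multi-warp d-split kernel
    if head_dim % 16 == 0:
        return "tc1"      # single-warp fallback
    return "half2"        # assumed fallback when no TC kernel applies
```

Both Gemma 4 head dims (256 and 512) are multiples of 64, so they select tc2 by default; a head_dim such as 48 or 80 would reach tc1 only.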

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… flash #146/#147

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
pekkah changed the title from "perf(gemma4-cuda): tensor-core flash-attention prefill (#146)" to "perf(gemma4-cuda): tensor-core flash-attention prefill — single-warp + multi-warp d-split (#146/#147)" on Jun 6, 2026
pekkah commented Jun 6, 2026


Update: multi-warp / d-split (#147) landed — default-on, +27-40%

The single-warp kernel above was occupancy-limited (+5%). The follow-up llm_flash_attn_prefill_tc2 fixes it: W=4 warps cooperate on each 16-query tile, head dim split across them → O is register-resident (no shared-O rescale), occupancy ~2 → ~16-20 warps/SM. Partial QK^T per warp reduced through a small shared S buffer; same no-transpose score→P trick downstream.

| N (tokens) | half2 (t/s) | tc1 (#146) (t/s) | tc2 (#147) (t/s) | tc2 vs half2 |
|---|---|---|---|---|
| 965 | 2582 | 2691 | 3269 | +26.6% |
| 2054 | 2732 | 2851 | 3818 | +39.8% |

Parity identical to scalar (~3-4e-4). Now default-on (SHARPI_PREFILL_FLASH_TC=0 → half2). Pinned off in the bit-exact/half2 Gemma4 oracles; 23 focused tests green. README updated 2853→3269. Closes #146 and #147.

pekkah and others added 5 commits June 6, 2026 12:06
After the TC flash work moved the prefill bottleneck to the matmul/FFN GEMMs
(profiling: ~53% of prefill), this probe times MatMulBatchedMmq at FFN-shaped
GEMMs and reports achieved int8 TOPS vs the ~160 TOPS dense peak:
  ffn-gate/up [8192x2048]: 53.7 TOPS (34%)
  ffn-down    [2048x8192]: 36.4 TOPS (23%)  <- worst, occupancy-starved
  qkv         [6144x2048]: 52.2 TOPS (33%)
MMQ is at 23-34% of TC peak (~3x headroom), matching the 2.2x end-to-end gap to
llama.cpp. ffn-down (few output rows over a long contraction) is occupancy-bound
the same way the single-warp flash kernel was — split-K + better tiling is the
next attack. Diagnostic only (asserts true; run via --filter).
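
The percentages above are the achieved throughput divided by the quoted ~160 TOPS peak. The sketch below shows that arithmetic and how a timing turns into TOPS; counting one multiply and one add per weight per token is the usual GEMM convention and my assumption about how the probe counts, and the function names are illustrative.

```python
PEAK_INT8_TOPS = 160.0  # dense peak quoted in the commit message


def achieved_tops(rows, cols, n_tok, seconds):
    """int8 TOPS for a [rows x cols] weight applied to n_tok tokens."""
    ops = 2 * rows * cols * n_tok     # one multiply + one add per MAC
    return ops / seconds / 1e12


def percent_of_peak(tops, peak=PEAK_INT8_TOPS):
    """Achieved throughput as a rounded percentage of peak."""
    return round(100.0 * tops / peak)


reported = {"ffn-gate/up": 53.7, "ffn-down": 36.4, "qkv": 52.2}
percentages = {name: percent_of_peak(t) for name, t in reported.items()}
```

Running the reported figures through `percent_of_peak` reproduces the 34%, 23%, and 33% in the list above.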

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…149, isolated)

The interleaved Q8_0 block (34 B = 2 B fp16 scale + 32 int8) puts the qs at a
2-byte offset, so every weight word in the MMQ load costs an extra global load +
__funnelshift (sharpi_uint_at). llm_mmq_q8_0_soa reads the weights from a struct-
of-arrays repack instead — quants 16 B-aligned (8 uint/block), fp16 scales separate
— so each weight word is a plain aligned load. Only the weight path changes;
activations (Q8_1) are already aligned.
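
To see why the 2-byte offset costs an extra load, it helps to emulate the unaligned word read on the host. The sketch below is an illustration of the general technique, not the source of sharpi_uint_at: it reads a 32-bit word at a 2-byte-aligned offset using only 4-byte-aligned loads plus a funnel shift, mirroring the semantics of CUDA's `__funnelshift_r`.

```python
import struct


def funnelshift_r(lo, hi, shift):
    """Low 32 bits of the 64-bit value hi:lo shifted right by (shift & 31)."""
    return (((hi << 32) | lo) >> (shift & 31)) & 0xFFFFFFFF


def uint_at(buf, byte_off):
    """Little-endian uint32 at a 2-byte-aligned offset via aligned loads only."""
    base = byte_off & ~3
    lo = struct.unpack_from("<I", buf, base)[0]
    if byte_off == base:
        return lo                      # already aligned: one load
    hi = struct.unpack_from("<I", buf, base + 4)[0]   # the extra load
    return funnelshift_r(lo, hi, 8 * (byte_off - base))


buf = bytes(range(72))
word = uint_at(buf, 2)   # first quant word of block 0 sits at byte offset 2
```

In the SoA layout the quants start on an aligned boundary, so the second load and the shift disappear and only the aligned branch remains.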

Validated in isolation (CudaMmqSoaTests):
- BIT-IDENTICAL to the interleaved MMQ (maxAbs 0, 0 diffs) across 5 shapes incl.
  ffn-down [2048x8192] and qkv.
- Speed at REAL prefill nTok=2048 (not the 1024 that mismeasured split-K):
    ffn-gate/up 54.6 -> 61.4 TOPS (+12.5% throughput, i.e. 11.1% less time)
    ffn-down    39.6 -> 49.1 TOPS (+24.0% throughput, i.e. 19.4% less time)
    qkv         54.6 -> 60.5 TOPS (+10.8% throughput, i.e. 9.7% less time)

A per-instruction-efficiency win (not occupancy), so it should translate end-to-end.
Next: model-wide SoA migration (upload repack + the decode matvec / embed / dequant
Q8_0 readers, since each weight is shared by prefill MMQ and decode).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…prefill (#149)

Repacks 2-D Q8_0 GEMM weights into the SoA layout at upload (mmqSoa flag /
SHARPI_MMQ_SOA, default off), so the prefill MMQ and decode dp4a matvec read the
weight with plain aligned loads instead of the 2-byte-misalignment funnelshift.

- llm_q8_0_repack_soa: one-time GPU repack interleaved [34B/block] → [quants
  rows*cols B][scales rows*nb fp16]. CudaBackend.RepackQ8_0Soa allocates the dest,
  repacks, frees the source, and marks the handle in _soaHandles.
- llm_mmq_q8_0_soa (refactored to the AoS signature, scales located internally) and
  llm_matvec_q8_0_dp4a_soa: aligned-load variants. MatMulBatchedMmq and the decode
  matvec auto-route to them per repacked handle — no caller changes.
- UploadWeight repacks 2-D Q8_0 (norms/biases are 1-D; embedding uploads elsewhere);
  output.weight auto-routes via the handle set.
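
The layout transform itself is small enough to model on the CPU. The following is a reference sketch of the interleaved-to-SoA repack and of a bit-identical read from both layouts; it assumes blocks are stored row-major in both layouts and that the scales start at byte rows*cols. The function names are illustrative and this is not the llm_q8_0_repack_soa kernel.

```python
import struct

QK = 32          # quants per Q8_0 block
AOS_BLOCK = 34   # 2-byte fp16 scale followed by 32 int8 quants


def repack_q8_0_soa(aos, rows, cols):
    """Interleaved [34 B/block] -> [quants rows*cols B][scales rows*nb fp16]."""
    nb = cols // QK
    quants = bytearray(rows * cols)
    scales = bytearray(rows * nb * 2)
    for blk in range(rows * nb):
        src = blk * AOS_BLOCK
        scales[2 * blk:2 * blk + 2] = aos[src:src + 2]
        quants[QK * blk:QK * blk + QK] = aos[src + 2:src + AOS_BLOCK]
    return bytes(quants + scales)


def dequant_aos(aos, blk, j):
    d = struct.unpack_from("<e", aos, blk * AOS_BLOCK)[0]
    q = struct.unpack_from("<b", aos, blk * AOS_BLOCK + 2 + j)[0]
    return d * q


def dequant_soa(soa, rows, cols, blk, j):
    d = struct.unpack_from("<e", soa, rows * cols + 2 * blk)[0]
    q = struct.unpack_from("<b", soa, QK * blk + j)[0]
    return d * q


# Tiny 2 x 64 weight: 4 blocks with distinct scales and quants.
rows, cols = 2, 64
aos = b"".join(
    struct.pack("<e32b", 0.5 * (b + 1), *[(7 * b + j) % 255 - 127 for j in range(QK)])
    for b in range(rows * cols // QK)
)
soa = repack_q8_0_soa(aos, rows, cols)
```

The repack only moves bytes, so the total size is unchanged and every dequantized value reads back identically from either layout, which is the property the bit-exact tests rely on.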

Verified BIT-IDENTICAL end-to-end (CudaMmqSoaE2ETests: Gemma 4 prefill + 4 decode
steps, maxAbs 0) and at the kernel level (CudaMmqSoaTests, +13-18% isolated GEMM).
A/B end-to-end (all-GPU, warm, N=2054): prefill 3849 → 4240 t/s (+10.2%), decode
neutral. 26 focused tests green (SoA off by default).

Opt-in until the remaining Q8_0 readers (dequant, fp32 matvec, gemm_n — used only by
the bit-exact/gemm-off oracles) are converted, then it can default on. Follow-up.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Completes the SoA Q8_0 migration so it can default on. Added SoA variants of the
three remaining interleaved readers — llm_matvec_q8_0 (fp32 decode), _gemm_n, and
llm_dequant_q8_0_to_f16 — each bit-identical to its AoS counterpart (aligned quant
bytes + one aligned fp16 scale per block, scales located at byte rows*cols). All
five Q8_0 readers (MMQ, dp4a, fp32 matvec, GEMM-N, dequant) now auto-route to the
SoA kernel per repacked handle; the embedding stays interleaved (uploaded separately,
not repacked).

mmqSoa now defaults ON (SHARPI_MMQ_SOA=0 reverts). With every reader SoA-aware the
bit-exact Gemma4 oracles exercise the SoA path and pass bit-exactly (a half-migration
would have let them falsely pass on garbage — they compare batched-vs-sequential, so
both sides must read the format correctly). 26 focused Gemma4/Cuda tests green incl.
GemmOff_MatchesSequentialBitExact.

End-to-end (all-GPU, warm, gemma-4-E4B Q8_0, 4070 Ti): prefill +10-12% now by default
(N=965 3297→3698, N=2054 3849→4240). Prefill ~4240 t/s = ~50% of llama.cpp's ~8475.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…A Q8_0 #149 default-on

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
pekkah changed the title from "perf(gemma4-cuda): tensor-core flash-attention prefill — single-warp + multi-warp d-split (#146/#147)" to "perf(gemma4-cuda): TC flash-attention prefill (#146/#147) + SoA Q8_0 weight layout (#149) — prefill 2853→4240" on Jun 6, 2026
…aHandles lifecycle

From the pr-review-toolkit cycle (code-reviewer, silent-failure-hunter, pr-test-analyzer,
comment-analyzer). No critical/high defects were found (all 5 Q8_0 readers are SoA-routed;
the one non-SoA reader, MatMulN2, throws on Q8_0 rather than reading garbage; NVRTC failures
surface loudly). Addressed the actionable items:

- Test gap (sev 8): the GPU repack kernel + the fp32-matvec / GEMM-N / dequant SoA readers
  had no GGUF-free coverage (only bench-machine model oracles). Added
  GpuRepack_AllSoaReaders_BitIdenticalToInterleaved: repacks on the GPU (production path) and
  asserts all five readers (dp4a, fp32, GEMM-N, dequant, MMQ) are bit-identical to interleaved
  across 3 shapes — runs on any CUDA box.
- _soaHandles lifecycle: switched HashSet→ConcurrentDictionary (matches the sibling handle
  tables, removes the latent thread-safety footgun the whole SoA-correctness invariant rested
  on) and prune it on Free + Dispose (was an unbounded leak across model load/free cycles).
- README: balanced a stray parenthesis in the Gemma 4 CUDA prefill cell.
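
The lifecycle fix amounts to keeping the set of repacked handles in step with allocation and release. The sketch below models that in Python with a lock-protected set; the real code is C# and uses a ConcurrentDictionary, so the class and method names here are hypothetical stand-ins.

```python
import threading


class SoaHandleSet:
    """Thread-safe record of which weight handles hold the SoA layout."""

    def __init__(self):
        self._lock = threading.Lock()
        self._handles = set()

    def mark(self, handle):
        with self._lock:
            self._handles.add(handle)

    def is_soa(self, handle):
        with self._lock:
            return handle in self._handles

    def free(self, handle):
        # Prune on free so entries do not accumulate across load/free cycles.
        with self._lock:
            self._handles.discard(handle)

    def dispose(self):
        with self._lock:
            self._handles.clear()

    def __len__(self):
        with self._lock:
            return len(self._handles)


def route(handles, handle):
    """Choose the reader variant for a weight handle."""
    return "soa" if handles.is_soa(handle) else "interleaved"
```

Pruning also means a handle that is no longer marked routes back to the interleaved reader, which matters if a handle value could ever be reused for a buffer that was not repacked.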

28 focused Gemma4/Cuda tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
pekkah commented Jun 6, 2026


pr-review-toolkit cycle (4 agents) — results

Ran code-reviewer, silent-failure-hunter, pr-test-analyzer, comment-analyzer over git diff master...HEAD.

No critical/high defects. Verified by the agents: all 5 Q8_0 readers are SoA-routed; the one non-SoA reader (MatMulN2) throws on Q8_0 rather than reading garbage; the embedding/MoE weights are never repacked (uploaded outside UploadWeight); NVRTC failures surface as load-time exceptions, not inference garbage; repack byte-math verified; kernel comments (fragment layouts, shared-mem formulas, guard conditions) are accurate and internally consistent.

Addressed (commit e028ea8):

  • Test gap (sev 8) — the GPU repack kernel + the fp32-matvec / GEMM-N / dequant SoA readers had no GGUF-free coverage (only bench-machine model oracles). Added GpuRepack_AllSoaReaders_BitIdenticalToInterleaved: repacks on the GPU (production path) and asserts all 5 readers bit-identical to interleaved across 3 shapes — runs on any CUDA box. ✅ all green.
  • _soaHandles lifecycle: the whole SoA-correctness invariant rested on this set. Switched HashSet→ConcurrentDictionary (matches the sibling handle tables, removes the latent thread-safety footgun) and pruned it on Free/Dispose (it was an unbounded leak across model load/free cycles).
  • README — balanced a stray parenthesis.

Noted, not addressed (lower-value test hardening, optional follow-ups):

  • TC flash kernels are only tested with startPos==0; continued-prefill (chat cache) can pass startPos>0 — worth one parity config.
  • No nTok < 16 (single partial tile) or genuinely TC1-only head_dim (%16 && !%64, e.g. 48/80) config; Gemma 4's 256/512 are both %64 so TC1 is only force-exercised via RunParity(tc2:false).
  • Comment perf/occupancy figures ("~2 warps/SM", "23-34% of TC peak") lack a "measured on RTX 4070 Ti" anchor and will rot.

These are hardening suggestions, not bugs — deferring.
