
perf(cpu,engine): dequant-once multi-input MoE dot kernels ("one GEMM per expert") — routed MoE is dequant-bound (#110 follow-up) #112

@pekkah


Summary

Follow-up to #110. The batched routed-MoE in #110 groups prompt tokens by selected expert and reads each expert's weight rows once per chunk, dotting them against every token routing to that expert. This amortized the DRAM weight read to ~1 ms/tok (all 256 experts are hit once per 512-token chunk). But each weight row is still re-dequantized once per token-dot via DispatchDot/DispatchDotQ8K, so the routed MoE is now dequant/compute-bound at ~9.2 ms/tok (≈35% of prefill wall).

Root cause

BatchedRoutedExperts loops, per (used-expert, row), over the M tokens routing to that expert and calls DispatchDotQ8K(weightRow, tokenInputQ8KS, …) M times. The weight row stays hot in cache (read amortized), but the Q3_K nibble/scale unpack is repeated M times. For Carnice's Q3_K routed experts that unpack is a large fraction of each dot.
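The loop shape described above can be illustrated with a small sketch. It is not the engine code: the 4-bit toy packing, `PerTokenDotSketch`, `UnpackRow`, and `RowAgainstTokens` are hypothetical stand-ins for the Q3_K unpack and `DispatchDotQ8K`, used only to show that the unpack runs once per token rather than once per row.

```csharp
using System;

static class PerTokenDotSketch
{
    public static int UnpackCalls; // counts how often a weight row is unpacked

    // Toy format (assumption, NOT real Q3_K): two signed 4-bit weights per byte,
    // low nibble first, each biased by 8.
    static void UnpackRow(ReadOnlySpan<byte> packed, Span<sbyte> dst)
    {
        UnpackCalls++;
        for (int i = 0; i < packed.Length; i++)
        {
            dst[2 * i]     = (sbyte)((packed[i] & 0x0F) - 8);
            dst[2 * i + 1] = (sbyte)((packed[i] >> 4) - 8);
        }
    }

    // Single-input dot: unpacks the row every time it is called.
    public static int Dot(ReadOnlySpan<byte> packedRow, ReadOnlySpan<sbyte> token)
    {
        Span<sbyte> w = stackalloc sbyte[packedRow.Length * 2];
        UnpackRow(packedRow, w);
        int acc = 0;
        for (int i = 0; i < w.Length; i++) acc += w[i] * token[i];
        return acc;
    }

    // Shape of the per-(expert, row) loop: M dots, hence M unpacks of the same row.
    public static void RowAgainstTokens(ReadOnlySpan<byte> packedRow, sbyte[][] tokens, Span<int> outputs)
    {
        for (int t = 0; t < tokens.Length; t++)
            outputs[t] = Dot(packedRow, tokens[t]);
    }
}
```

With M tokens routed to an expert, `UnpackCalls` grows by M per weight row even though the packed bytes never change between calls.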

Proposed work

Add dequant-once, multi-input dot kernels so an expert weight row is unpacked once and dotted against all M of its tokens — the "one GEMM per expert" the original issue called for. The int-domain arithmetic must stay bit-identical to the single-input kernels (unpack the weight row to its int8 + per-sub-block scales once, then do M int8 dots — same accumulation, same scale-multiply order).

  • SimdKernels.DotQ3K_Q8KS_NIn(byte* row, byte** q8ksInputs, float* outputs, int n, int cols)
  • SimdKernels.DotQ8_0_Q8KS_NIn(...) (lower priority — Q8_0 has little to amortize)
  • Optionally Q4_K/Q5_K/Q6_K N-input variants for other models.
  • DotQ6K_Q8K_2In (SimdKernels.cs) is the existing 2-input precedent to generalize.

Then rewrite BatchedRoutedExperts Phase A (gate/up) and Phase C (down) to call the N-input kernels per weight row instead of looping DispatchDot.
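A minimal sketch of the dequant-once idea, under stated assumptions: the layout below (4-bit weights, one `sbyte` scale per 16-weight sub-block) is a toy format, not Q3_K, and `DequantOnceSketch`, `Dot1`, `DotNIn`, and `DotUnpacked` are hypothetical names rather than the proposed `SimdKernels` API. The point it illustrates is that both paths share one inner dot routine, so the accumulation and scale-multiply order cannot diverge.

```csharp
using System;

static class DequantOnceSketch
{
    const int SubBlock = 16; // toy sub-block size (assumption)

    // Toy layout (assumption, NOT real Q3_K): two signed 4-bit weights per byte,
    // low nibble first, plus one sbyte scale per 16-weight sub-block.
    static void Unpack(ReadOnlySpan<byte> packed, Span<sbyte> w)
    {
        for (int i = 0; i < packed.Length; i++)
        {
            w[2 * i]     = (sbyte)((packed[i] & 0x0F) - 8);
            w[2 * i + 1] = (sbyte)((packed[i] >> 4) - 8);
        }
    }

    // Shared int8 dot on an already-unpacked row. Both entry points call this,
    // so the int sums and the float scale-multiply order are identical.
    static float DotUnpacked(ReadOnlySpan<sbyte> w, ReadOnlySpan<sbyte> scales,
                             ReadOnlySpan<sbyte> x, float xScale)
    {
        float acc = 0f;
        for (int b = 0; b < scales.Length; b++)
        {
            int sum = 0;
            for (int i = b * SubBlock; i < (b + 1) * SubBlock; i++) sum += w[i] * x[i];
            acc += xScale * (scales[b] * sum);
        }
        return acc;
    }

    // Single-input reference: unpack, then one dot.
    public static float Dot1(byte[] packed, sbyte[] scales, sbyte[] x, float xScale)
    {
        Span<sbyte> w = stackalloc sbyte[packed.Length * 2];
        Unpack(packed, w);
        return DotUnpacked(w, scales, x, xScale);
    }

    // N-input variant: unpack once, then one dot per input.
    public static void DotNIn(byte[] packed, sbyte[] scales, sbyte[][] xs,
                              float[] xScales, float[] outputs)
    {
        Span<sbyte> w = stackalloc sbyte[packed.Length * 2];
        Unpack(packed, w);
        for (int t = 0; t < xs.Length; t++)
            outputs[t] = DotUnpacked(w, scales, xs[t], xScales[t]);
    }
}
```

A real kernel would keep the unpacked row in SIMD registers or a small scratch buffer and vectorize the inner loop; the sketch only fixes the structure (one unpack, M dots) that the parity requirement constrains.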

Byte-parity (mandatory)

The A/B oracle BatchedPrefill_BitwiseMatchesSequential_Carnice (#110) compares batched-vs-sequential prefill bitwise — it will catch any reordering. Keep the per-token top-k reduce order (downPartial reduce) unchanged. See the Q8_KS validation note (#107) and the K/V MatVecDual parity regression note for the failure mode.
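For the unit tests against the single-input reference, a bit-pattern comparison is stricter than `==` on floats. The helper below is a sketch with hypothetical names (`BitParitySketch`, `FirstBitMismatch`, `SumInOrder`), not the existing oracle; it also shows why reduce order matters, since float addition is not associative.

```csharp
using System;

static class BitParitySketch
{
    // Index of the first element whose bit pattern differs, or -1 if all match.
    // Unlike ==, this distinguishes +0 from -0 and treats identical NaNs as equal.
    public static int FirstBitMismatch(ReadOnlySpan<float> a, ReadOnlySpan<float> b)
    {
        if (a.Length != b.Length) return Math.Min(a.Length, b.Length);
        for (int i = 0; i < a.Length; i++)
            if (BitConverter.SingleToInt32Bits(a[i]) != BitConverter.SingleToInt32Bits(b[i]))
                return i;
        return -1;
    }

    // Plain left-to-right float sum; the order of `terms` is the reduce order.
    public static float SumInOrder(ReadOnlySpan<float> terms)
    {
        float acc = 0f;
        foreach (float t in terms) acc += t;
        return acc;
    }
}
```

Summing `{1e8f, 1f, -1e8f}` left to right loses the `1f` (float spacing near 1e8 is 8), while `{1e8f, -1e8f, 1f}` keeps it, so the same terms in a different order give different bits. This assumes single-precision evaluation, as on current .NET x64 and Arm64 runtimes.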

Expected impact

Routed MoE is ~35% of prefill wall; amortizing the Q3_K unpack across ~16 tokens/expert (512-chunk) could roughly halve the dequant cost. Best paired with #111 (trunk batching) since the trunk dominates once the MoE shrinks.
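A rough cost model behind the "roughly halve" estimate, as a sketch only. The symbols are assumptions introduced here: $u$ is the cost of unpacking one weight row, $d$ the cost of the int8 dot on an already-unpacked row, and $M$ the number of tokens routed to the expert in a chunk.

```math
C_{\text{before}} = M\,(u + d), \qquad
C_{\text{after}} = u + M\,d, \qquad
\frac{C_{\text{after}}}{C_{\text{before}}} = \frac{u + M\,d}{M\,(u + d)}
```

With $M = 16$ and, purely for illustration, $u = d$ (the unpack being about half of each single-input dot), the ratio is $17/32 \approx 0.53$. The true $u/d$ split for Q3_K has not been measured here, so the actual gain depends on profiling.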

Scope (files)

  • src/SharpInference.Cpu/SimdKernels.cs (or SimdKernels.*.cs) — N-input dot kernels
  • src/SharpInference.Engine/CudaHybridGdnForwardPass.cs: BatchedRoutedExperts Phase A/C
  • tests/SharpInference.Tests.ForwardPass/ — unit tests for the N-input kernels vs the single-input reference
