Summary
Follow-up to #110. The batched routed-MoE in #110 groups prompt tokens by selected expert and reads each expert's weight rows once per chunk, dotting them against every token routed to that expert. This amortized the DRAM weight read to ~1 ms/tok (all 256 experts are hit once per 512-token chunk). However, each weight row is still re-dequantized once per token dot via DispatchDot/DispatchDotQ8K, so the routed MoE is now dequant/compute-bound at ~9.2 ms/tok (≈35% of prefill wall time).
Root cause
For each (used expert, weight row) pair, BatchedRoutedExperts loops over the M tokens routed to that expert and calls DispatchDotQ8K(weightRow, tokenInputQ8KS, …) once per token, i.e. M times. The weight row stays hot in cache (the read is amortized), but the Q3_K nibble/scale unpack is repeated M times. For Carnice's Q3_K routed experts, that unpack is a large fraction of each dot.
Proposed work
Add dequant-once, multi-input dot kernels so that an expert weight row is unpacked once and dotted against all M of its tokens: the "one GEMM per expert" the original issue called for. The int-domain arithmetic must stay bit-identical to the single-input kernels: unpack the weight row to its int8 values plus per-sub-block scales once, then do M int8 dots with the same accumulation and the same scale-multiply order.
- SimdKernels.DotQ3K_Q8KS_NIn(byte* row, byte** q8ksInputs, float* outputs, int n, int cols)
- SimdKernels.DotQ8_0_Q8KS_NIn(...) (lower priority: Q8_0 has little to amortize)
- Optionally, Q4_K/Q5_K/Q6_K N-input variants for other models.
DotQ6K_Q8K_2In (SimdKernels.cs) is the existing 2-input precedent to generalize.
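A minimal sketch of the dequant-once pattern, written in C for brevity but following the parameter order of the proposed DotQ3K_Q8KS_NIn (assuming n is the number of inputs and cols the row length). To stay short it uses a hypothetical toy block format (32 weights per block, 4-bit values, one int8 scale per 16-weight sub-block) rather than the real Q3_K layout, plus a toy stand-in for Q8_KS; all names here are illustrative. The structural point is that both kernels call the same int-domain block_dot helper and every output accumulates blocks in the same order, so the N-input result should match the single-input reference bit for bit.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical toy format, NOT the real Q3_K layout: 32 weights per block,
   4-bit values biased by 8, one int8 scale per 16-weight sub-block. */
#define QK 32
#define SUB 16
#define NSUB (QK / SUB)

typedef struct { float d; int8_t scales[NSUB]; uint8_t qs[QK / 2]; } ToyQBlock;
/* Toy stand-in for a Q8_KS activation block: one float scale, int8 values. */
typedef struct { float d; int8_t qs[QK]; } ToyQ8Block;

/* Stands in for the Q3_K nibble/scale unpack that is repeated per token today. */
static void unpack_block(const ToyQBlock *b, int8_t out[QK]) {
    for (int j = 0; j < QK; j++)
        out[j] = (int8_t)(((b->qs[j / 2] >> ((j & 1) * 4)) & 0x0F) - 8);
}

/* Shared int-domain block dot. Both kernels call it, so the integer
   accumulation and the scale-multiply order are identical by construction. */
static float block_dot(const ToyQBlock *w, const int8_t wq[QK], const ToyQ8Block *x) {
    int32_t isum = 0;
    for (int s = 0; s < NSUB; s++) {
        int32_t sub = 0;
        for (int j = 0; j < SUB; j++)
            sub += wq[s * SUB + j] * x->qs[s * SUB + j];
        isum += sub * w->scales[s];
    }
    return w->d * x->d * (float)isum;
}

/* Single-input reference: unpacks every block on every call. */
static float dot_toyq_q8(const ToyQBlock *row, const ToyQ8Block *x, int cols) {
    assert(cols % QK == 0);
    int8_t wq[QK];
    float sum = 0.0f;
    for (int b = 0; b < cols / QK; b++) {
        unpack_block(&row[b], wq);
        sum += block_dot(&row[b], wq, &x[b]);
    }
    return sum;
}

/* N-input variant: each block is unpacked once and dotted against all n
   inputs. outputs[i] still accumulates blocks in order 0, 1, 2, ..., so it
   should match dot_toyq_q8(row, inputs[i], cols) bit for bit, assuming strict
   IEEE float evaluation (no -ffast-math). */
static void dot_toyq_q8_nin(const ToyQBlock *row, const ToyQ8Block *const *inputs,
                            float *outputs, int n, int cols) {
    assert(cols % QK == 0);
    int8_t wq[QK];
    for (int i = 0; i < n; i++) outputs[i] = 0.0f;
    for (int b = 0; b < cols / QK; b++) {
        unpack_block(&row[b], wq);            /* once per block, not per token */
        for (int i = 0; i < n; i++)
            outputs[i] += block_dot(&row[b], wq, &inputs[i][b]);
    }
}

/* Bitwise comparison, in the spirit of the byte-parity oracle. */
static int same_bits(float a, float b) {
    return memcmp(&a, &b, sizeof a) == 0;
}
```

Looping blocks on the outside and inputs on the inside keeps one unpacked block hot while all n dots consume it, without changing any single output's accumulation sequence. A unit test in the spirit of the scope list would compare each N-input output against the single-input reference with a bit-pattern check rather than a tolerance.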
Then rewrite BatchedRoutedExperts Phase A (gate/up) and Phase C (down) to call the N-input kernel once per weight row instead of looping DispatchDot over that expert's tokens.
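A sketch of the Phase A/C loop shape for a single expert, assuming an N-input kernel with the parameter order of the proposed DotQ3K_Q8KS_NIn (row, inputs, outputs, n, cols) passed as a function pointer. NInDotFn, expert_matvec_batched, and the placeholder float_dot_nin are hypothetical names; the placeholder treats rows and inputs as plain float vectors only so that the loop runs. The kernel is called once per weight row with all m routed tokens, and each result is scattered back to its token's own output row, leaving the per-token downPartial reduce free to run in its existing order.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical N-input kernel type: one row, n inputs, n outputs. */
typedef void (*NInDotFn)(const void *row, const void *const *inputs,
                         float *outputs, int n, int cols);

/* Placeholder kernel: plain float dots. A real kernel would unpack a
   quantized row once and then do n int8 dots. */
static void float_dot_nin(const void *row, const void *const *inputs,
                          float *outputs, int n, int cols) {
    const float *w = row;
    for (int i = 0; i < n; i++) {
        const float *x = inputs[i];
        float s = 0.0f;
        for (int c = 0; c < cols; c++) s += w[c] * x[c];
        outputs[i] = s;
    }
}

/* One expert's matvec over its m routed tokens.
   tokens[i]   : index of the i-th token routed to this expert
   inputs[t]   : token t's (quantized) activation
   out         : [token][row] layout with row stride out_stride
   gathered    : caller-provided scratch for m input pointers
   scratch     : caller-provided scratch for m kernel outputs */
static void expert_matvec_batched(NInDotFn dot, const void *const *rows, int nrows,
                                  const void *const *inputs, const int *tokens, int m,
                                  int cols, float *out, int out_stride,
                                  const void **gathered, float *scratch) {
    for (int i = 0; i < m; i++) gathered[i] = inputs[tokens[i]];
    for (int r = 0; r < nrows; r++) {
        dot(rows[r], gathered, scratch, m, cols);   /* one call per weight row */
        for (int i = 0; i < m; i++)                 /* scatter back per token */
            out[(size_t)tokens[i] * out_stride + r] = scratch[i];
    }
}
```

Gathering the m input pointers once per expert lets every weight row reuse the same list. The scatter writes exactly the per-token values the sequential path would produce, provided the kernel itself is bit-identical to the single-input one.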
Byte-parity (mandatory)
The A/B oracle BatchedPrefill_BitwiseMatchesSequential_Carnice (#110) compares batched and sequential prefill outputs bitwise, so it will catch any reordering of the floating-point accumulation. Keep the per-token top-k reduce order (the downPartial reduce) unchanged. See the Q8_KS validation note (#107) and the K/V MatVecDual parity regression note for the failure mode.
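Why any reordering trips the oracle: float addition is not associative, so summing a token's k expert partials in a different order can change the result's bits even when every individual product is identical. A minimal C illustration, where reduce_topk is a hypothetical stand-in for the downPartial reduce, not its actual code:

```c
#include <assert.h>

/* Hypothetical stand-in for the per-token top-k combine:
   acc = sum of weight[e] * partial[e], visiting experts in order[]. */
static float reduce_topk(const float *partial, const float *weight,
                         const int *order, int k) {
    float acc = 0.0f;
    for (int i = 0; i < k; i++) {
        int e = order[i];
        acc += weight[e] * partial[e];   /* each step rounds to float */
    }
    return acc;
}
```

With partials {1e8f, 1.0f, -1e8f} and unit weights, visiting experts in order {0, 1, 2} yields 0.0f, because the 1.0f term is lost to rounding against 1e8f (adjacent floats near 1e8 are 8 apart). Visiting them in order {0, 2, 1} cancels the large terms first and yields 1.0f. The inputs and products are the same; only the order differs.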
Expected impact
The routed MoE is ~35% of prefill wall time; amortizing the Q3_K unpack across ~16 tokens per expert (per 512-token chunk) could roughly halve its dequant cost. This is best paired with #111 (trunk batching), since the trunk dominates once the MoE shrinks.
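The ~16 tokens/expert figure and the "roughly halve" estimate can be reconstructed with a back-of-envelope cost model. This is a sketch, not a measurement: it assumes top-k = 8 routing (which is what T = 512 tokens and E = 256 experts giving 16 tokens/expert implies), ignores load imbalance across experts, and splits each single-input dot into an unpack cost u and an int8-dot cost d, with f the unpack fraction.

```latex
\begin{aligned}
\bar{M} &= \frac{T\,k}{E} = \frac{512 \cdot 8}{256} = 16 \\
t_{\text{now}} &= M\,(u + d), \qquad t_{\text{new}} = u + M\,d \\
\frac{t_{\text{new}}}{t_{\text{now}}} &= 1 - f + \frac{f}{M}, \qquad f = \frac{u}{u + d} \\
\frac{t_{\text{new}}}{t_{\text{now}}} = \tfrac{1}{2},\; M = 16 &\;\Rightarrow\; f = \frac{1/2}{1 - 1/16} = \frac{8}{15} \approx 0.53
\end{aligned}
```

Under this model, halving the routed-MoE dequant/compute time needs the unpack to be a bit over half of each single-input dot, consistent with the root-cause observation that it is a large fraction. Experts that receive only one token (M = 1) gain nothing, so the real saving depends on the per-expert token distribution.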
Scope (files)
- src/SharpInference.Cpu/SimdKernels.cs (or SimdKernels.*.cs): N-input dot kernels
- src/SharpInference.Engine/CudaHybridGdnForwardPass.cs: BatchedRoutedExperts Phase A/C
- tests/SharpInference.Tests.ForwardPass/: unit tests for the N-input kernels vs. the single-input reference