Context
The #194 weight-stationary kernels are stamped per compile-time batch capacity (2/4/8/16, CudaWsKernels.Variants); MatMulBatchedWeightStationary delegates nTok > 16 to the GEMM-N matvec, which re-streams the weight once per token. So with SHARPI_MAX_BATCH > 16, decode steps at N > 16 lose the weight-amortization win exactly where it matters most.
Nobody runs that configuration today (a 12 GB card can't hold 16+ full-context KV caches for Qwen3-8B at ctx 2048 anyway, at ~604 MB/seq), so this is low priority, but the cliff is silent: nothing signals that the fallback has kicked in.
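The ~604 MB/seq figure can be reproduced with a back-of-envelope footprint calculation. This is a minimal sketch, assuming Qwen3-8B's published shape (36 layers, 8 KV heads, head_dim 128) and an unpaged K+V cache stored as fp32; the fp32 element size is an assumption inferred from the figure (fp16 would give ~302 MB/seq). The helpers kv_bytes_per_seq and max_ctx_for are hypothetical, not Sharpi code.

```python
# Back-of-envelope KV-cache footprint per sequence.
# Model shape: Qwen3-8B's published config. Element size: ASSUMED fp32,
# because that is what reproduces the ~604 MB/seq figure in this issue.
LAYERS = 36
KV_HEADS = 8
HEAD_DIM = 128
BYTES_PER_ELEM = 4  # assumed fp32 KV; fp16 would halve every number below


def kv_bytes_per_seq(ctx: int) -> int:
    """Bytes for K and V across all layers for one sequence of `ctx` tokens."""
    return 2 * LAYERS * KV_HEADS * HEAD_DIM * BYTES_PER_ELEM * ctx


def max_ctx_for(n_seqs: int, kv_budget_bytes: int) -> int:
    """Largest ctx such that n_seqs full caches fit in kv_budget_bytes."""
    per_token_all_seqs = n_seqs * kv_bytes_per_seq(1)
    return kv_budget_bytes // per_token_all_seqs


if __name__ == "__main__":
    per_seq = kv_bytes_per_seq(2048)
    print(f"per seq @ ctx 2048: {per_seq / 1e6:.0f} MB")      # 604 MB
    # KV alone for 16 sequences, before any weights or scratch buffers:
    print(f"16 seqs @ ctx 2048: {16 * per_seq / 1e9:.2f} GB")  # 9.66 GB
```

With an illustrative 6 GB KV budget (a made-up number; the real headroom depends on the weight format and scratch buffers), max_ctx_for(16, 6 * 10**9) returns 1271, which is the kind of shortened ctx an N=16 bench cell would need.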
Options
Sub-batch split: drive N > 16 as ⌈N/16⌉ weight-stationary calls. The weight is re-streamed once per sub-batch instead of once per token, so amortization is still up to 16× better than GEMM-N (exactly 16× when N is a multiple of 16). Trivial, no new kernels; likely the right first move.
Extend capacities (NT=32): needs a register-pressure/occupancy check, since 32 fp32 accumulators per thread on top of the dp4a working set may spill on the Q4_K 8-warp kernel.
SHARPI_BATCH_DECODE_GEMM=1 at N=16/32: the MMQ/GEMM fixed per-step costs (int8 activation conversion + Q6_K output fp16 dequant) may finally amortize there, and the "curves converge near N=8" note in #190 (feat(cuda): continuous batching on the CUDA backend) predates the WS bar moving from 106→183.
Either way, add an N=16 cell to CudaBatchedDecodeBench (it needs a shorter prompt/ctx so 16 caches fit in 12 GB).
Refs: #194, #190.
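The sub-batch split option can be sketched as a small dispatch planner that maps N decode tokens onto the stamped capacities and counts weight passes. This is a hypothetical illustration: only the 2/4/8/16 capacities come from the issue, while plan_sub_batches, weight_passes, and the remainder policy (send the tail to the smallest capacity that fits, rather than padding it to 16) are assumptions, not the actual MatMulBatchedWeightStationary logic.

```python
# Hypothetical sketch of the "sub-batch split" dispatch for N > 16.
# CAPACITIES mirrors the compile-time WS variants (2/4/8/16); all function
# names and the remainder policy are illustrative assumptions.
CAPACITIES = (2, 4, 8, 16)
MAX_CAP = max(CAPACITIES)


def smallest_capacity_for(n: int) -> int:
    """Smallest stamped capacity that can hold n tokens (requires n <= MAX_CAP)."""
    for cap in CAPACITIES:
        if n <= cap:
            return cap
    raise ValueError(f"{n} exceeds the largest capacity {MAX_CAP}")


def plan_sub_batches(n_tok: int) -> list[tuple[int, int, int]]:
    """Split n_tok into (start, count, capacity) weight-stationary calls.

    Full chunks of MAX_CAP tokens use the 16-wide kernel; the tail uses the
    smallest capacity that fits (assumed policy), so 20 tokens -> 16 + 4.
    """
    plan = []
    start = 0
    while start < n_tok:
        count = min(MAX_CAP, n_tok - start)
        plan.append((start, count, smallest_capacity_for(count)))
        start += count
    return plan


def weight_passes(n_tok: int) -> tuple[int, int]:
    """(GEMM-N passes, sub-batched WS passes) over one weight matrix per step."""
    return n_tok, len(plan_sub_batches(n_tok))
```

For example, plan_sub_batches(20) yields [(0, 16, 16), (16, 4, 4)], and weight_passes(32) is (32, 2), the full 16× reduction, whereas weight_passes(17) is (17, 2), only 8.5×. Padding the tail to 16 instead would keep a single kernel variant in play at the cost of wasted accumulator lanes; which is cheaper is a question for the N=16/32 bench cells.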