
perf(cuda): batched-decode routing above WS capacity — N>16 silently falls back to GEMM-N matvec (#194 follow-up) #198

@pekkah


Context

The #194 weight-stationary kernels are stamped per compile-time batch capacity (2/4/8/16, CudaWsKernels.Variants); MatMulBatchedWeightStationary delegates nTok > 16 to the GEMM-N matvec, which re-streams the weight per token. So with SHARPI_MAX_BATCH > 16, decode steps at N>16 lose the weight-amortization win exactly where it matters most.

Nobody runs that configuration today (a 12 GB card can't hold 16+ full-context KV caches for Qwen3-8B at ctx 2048 anyway — ~604 MB/seq), so this is low priority, but the cliff is silent.
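For reference, the ~604 MB/seq figure can be reproduced with a quick back-of-envelope calculation. This is only a sketch: the shape values (36 layers, 8 KV heads, head dim 128) and the fp32 cache element size are assumptions not stated in this issue; they are used here because they reproduce the number above.

```python
def kv_bytes_per_seq(n_layers, n_kv_heads, head_dim, ctx, bytes_per_elem):
    """Bytes of KV cache for one sequence at full context."""
    # K and V each store n_kv_heads * head_dim values per layer per token.
    return 2 * n_layers * n_kv_heads * head_dim * ctx * bytes_per_elem

# Assumed model shape and cache dtype (not taken from the issue text).
per_seq = kv_bytes_per_seq(n_layers=36, n_kv_heads=8, head_dim=128,
                           ctx=2048, bytes_per_elem=4)
total_16 = 16 * per_seq

print(f"per sequence: {per_seq / 1e6:.0f} MB")
print(f"16 sequences: {total_16 / 1e9:.2f} GB")
```

Under these assumptions, 16 full-context caches take most of a 12 GB card before any weights are loaded.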

Options

  1. Sub-batch split: drive N>16 as ⌈N/16⌉ weight-stationary calls. The weight is re-streamed once per sub-batch instead of once per token, so amortization is up to 16× better than GEMM-N. Trivial, no new kernels; likely the right first move.
  2. Extend capacities (NT=32): needs a register-pressure/occupancy check — 32 fp32 accumulators/thread on top of the dp4a working set may spill on the Q4_K 8-warp kernel.
  3. Compute-bound crossover: re-bench SHARPI_BATCH_DECODE_GEMM=1 at N=16/32. The MMQ/GEMM fixed per-step costs (int8 activation conversion + Q6_K output fp16 dequant) may finally amortize there, and #190's "curves converge near N=8" note predates the WS bar moving from 106→183.
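A minimal sketch of the sub-batch split from option 1, to make the weight-pass arithmetic concrete. The function names are hypothetical and this is host-side planning only; it does not model how MatMulBatchedWeightStationary would actually dispatch each chunk or which capacity variant the tail chunk would use.

```python
from math import ceil

WS_MAX_CAPACITY = 16  # largest weight-stationary capacity named in this issue

def split_sub_batches(n_tok, capacity=WS_MAX_CAPACITY):
    """Return (offset, count) pairs covering n_tok tokens in chunks of at most `capacity`."""
    if n_tok < 0 or capacity <= 0:
        raise ValueError("n_tok must be >= 0 and capacity must be > 0")
    return [(off, min(capacity, n_tok - off)) for off in range(0, n_tok, capacity)]

def weight_passes(n_tok, capacity=WS_MAX_CAPACITY):
    """Number of times the weight is streamed when splitting into sub-batches."""
    return ceil(n_tok / capacity)

for n in (17, 32, 48):
    chunks = split_sub_batches(n)
    # Per-token matvec streams the weight n times; the split streams it once per chunk.
    print(n, chunks, f"{n / weight_passes(n):.1f}x fewer weight passes than per-token")
```

The gain only reaches the full 16× when N is a multiple of 16; a ragged tail (for example N=17, which splits into 16 + 1) pays one extra weight pass for a nearly empty sub-batch.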

Either way, add an N=16 cell to CudaBatchedDecodeBench (needs a shorter prompt/ctx so 16 caches fit in 12 GB).
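To size that bench cell, a rough estimate of the largest context that lets 16 caches fit, scaling linearly from the ~604 MB/seq at ctx 2048 quoted above. The KV budget is a placeholder assumption (whatever is left on the card after weights and scratch buffers), not a measured value, and `max_ctx` is a hypothetical helper.

```python
REF_BYTES_PER_SEQ = 604_000_000  # ~604 MB/seq, as quoted in this issue
REF_CTX = 2048                   # context length that figure refers to

def max_ctx(kv_budget_bytes, n_seq):
    """Largest context length such that n_seq KV caches fit in kv_budget_bytes,
    assuming cache size scales linearly with context."""
    return (kv_budget_bytes * REF_CTX) // (n_seq * REF_BYTES_PER_SEQ)

# Placeholder budget: assume 6 GB of the 12 GB card is free for KV caches.
budget = 6_000_000_000
print("max ctx for 16 sequences:", max_ctx(budget, 16))
```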

Refs: #194, #190.
