[DRAFT] Add register-only cross-lane reduction for attention #2359

Open

stefankoncarevic wants to merge 8 commits into ROCm:develop from stefankoncarevic:blockwise-reduce-crosslane

Conversation

@stefankoncarevic
Contributor

Motivation

In the current implementation, configurations with partialR=2 and blockSize <= nonReductionDimSizeProduct (the NR-Large-Tree path) perform the reduction entirely through LDS, even though only a single 2-way reduction is needed. This is suboptimal because:

- A 2-way reduction can be done with a single cross-lane register exchange (see the sketch after this list)
- The LDS store + barrier adds latency that is disproportionate to the work
- Modern AMD GPUs provide efficient cross-lane instructions that can replace this pattern
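
As a rough illustration of the first point, a 2-way reduction in registers looks like this (a minimal HIP-style sketch with hypothetical names, not the PR's actual MLIR lowering):

```cpp
// Hypothetical sketch: each lane holds one partial value, and its 2-way
// reduction partner sits `dist` lanes away. A single register exchange
// replaces the LDS store + barrier + load round-trip.
__device__ float reduce2WayInRegisters(float partial, int dist) {
  float other = __shfl_xor(partial, dist); // exchange with lane (lane ^ dist)
  return partial + other; // both partner lanes now hold the reduced value
}
```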

Technical Details

Test Plan

Test Result

Submission Checklist

…per extraction

Restructure the blockwise reduce rewrite pattern in BlockwiseGemmToThreadwise.cpp
to improve clarity and maintainability, and to enable DPP-based reductions via
gpu.SubgroupReduceOp.

Shuffle decision logic:
- Introduce has2DThreadLayout guard (mTidPerWave > 0 && nTidPerWave > 0) to
  clearly separate GEMM-style 2D thread layouts from general cases
- Path 1 (Shuffle+DPP): activates when blockSize > nrDimProduct and the
  per-thread subtile is [1,1] with rDim == 1, using gpu.shuffle to transpose
  data from WMMA/MFMA strided layout into contiguous DPP-compatible layout
- Path 2 (Serial XOR): activates when blockSize <= nrDimProduct, performing
  log2(rDim) XOR butterfly reduction steps within a wave at stride nTidPerWave
  (see the butterfly sketch after this list)
- Initial LDS store is deferred: only performed when neither shuffle path applies,
  avoiding unnecessary LDS traffic for shuffle-eligible configurations
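
A minimal sketch of what the Path 2 butterfly computes (HIP-style, hypothetical names; the real steps are emitted as MLIR in BlockwiseGemmToThreadwise.cpp):

```cpp
// Hypothetical sketch of the serial XOR butterfly: rDim partial values per
// non-reduction slot, interleaved across the wave at stride nTidPerWave.
// Each of the log2(rDim) steps halves the number of distinct partials.
__device__ float xorButterflyReduce(float v, int rDim, int nTidPerWave) {
  for (int stride = nTidPerWave; stride < rDim * nTidPerWave; stride *= 2)
    v += __shfl_xor(v, stride); // exchange with the lane `stride` away
  return v; // each participating lane ends up with the full reduction
}
```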

Parallel reduction with DPP:
- Use gpu.SubgroupReduceOp with cluster_size for DPP-eligible reductions
  (power-of-2 active threads, cluster_size <= waveSize)
- Only the reduction group leader (rtid == 0) writes the result back to LDS,
  followed by a barrier and broadcast read
- Use bitwise AND/SHRU for thread ID decomposition (rtid, nrtid) on the DPP
  path and for power-of-2 non-reduction dimensions; fall back to DIV/REM
  for non-power-of-2 cases (see the decomposition sketch after this list)
- Force scalar accumulation (vectorLen = 1) during threadwise pre-reduction
  on the DPP path to ensure correct element-wise reduction before SubgroupReduceOp
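
For the power-of-2 case, the AND/shift decomposition amounts to the following (a sketch with hypothetical names, assuming the reduction index occupies the low bits of the flat thread ID):

```cpp
#include <cstdint>

// Hypothetical sketch of the tid split on the DPP path. When clusterSize is
// a power of two, DIV/REM collapse to a mask and a shift (AND/SHRU in MLIR).
struct SplitTid { int64_t rtid, nrtid; };
static SplitTid splitTid(int64_t tid, int64_t clusterSize) {
  if ((clusterSize & (clusterSize - 1)) == 0) { // power of two
    int log2C = __builtin_ctzll(clusterSize);
    return {tid & (clusterSize - 1), tid >> log2C};
  }
  return {tid % clusterSize, tid / clusterSize}; // non-power-of-2 fallback
}
```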

Helper extraction:
- getPerWaveThreadCounts: promote to a static member function; extracts the
  m_tid and n_tid counts from the Merge transform of the tid slice view
- shuffleRearrangeForDPP: encapsulates the gpu.shuffle-based transposition from
  the strided WMMA/MFMA layout to the contiguous DPP layout
  (sourceLane = (lane % clusterSize) * stride + lane / clusterSize; see the
  lane-mapping sketch after this list)
- readReducedResultsFromLDS: consolidates the repeated pattern of barrier +
  ThreadwiseReadInto from LDS into output registers (and optional extra output)
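
The lane mapping in shuffleRearrangeForDPP works out to the following (HIP-style sketch; the wrapper is hypothetical, only the sourceLane formula comes from the PR):

```cpp
// Hypothetical sketch of the strided-to-contiguous transpose: each lane pulls
// its value from the lane holding the element DPP expects adjacent to it,
// turning the WMMA/MFMA strided layout into contiguous clusters.
__device__ float rearrangeForDPP(float v, int lane, int clusterSize, int stride) {
  int sourceLane = (lane % clusterSize) * stride + lane / clusterSize;
  return __shfl(v, sourceLane); // gpu.shuffle idx in the actual lowering
}
```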

Tree reduction path:
- Retained as fallback for non-DPP-eligible configurations
  (non-power-of-2 thread counts or cluster_size > waveSize)
- Scope ceilPowerOf2 computation and treeMaxActiveThreads naming to this path
  (see the ceilPowerOf2 sketch after this list)
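
For reference, a ceilPowerOf2 helper typically looks like this (a sketch; the PR's actual implementation may differ):

```cpp
#include <cstdint>

// Hypothetical sketch: round x up to the next power of two, so the tree
// reduction can size its first step for non-power-of-2 thread counts.
static int64_t ceilPowerOf2(int64_t x) {
  return x <= 1 ? 1 : int64_t(1) << (64 - __builtin_clzll(uint64_t(x - 1)));
}
```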

New test: blockwise_reduce_dpp_cluster_sizes.mlir
- Integration test covering DPP reduction with cluster sizes 2, 4, 8, 16, 32, 64
- Validates both sum (rand=none, all ones) and max (rand=fixed) reductions
- All test configurations use blockSize <= waveSize to ensure single-wave
  execution on both RDNA (waveSize=32) and CDNA (waveSize=64)
- cluster_size=64 falls back to tree reduction on RDNA since 64 > waveSize=32

…ion kernels

Remove the shuffle+DPP transpose path and serial XOR butterfly reduction
from BlockwiseBroadcastReduceOp lowering. These paths used gpu.shuffle
to rearrange data between WMMA/MFMA strided layout and contiguous DPP
layout, adding complexity without consistent performance benefit.
The DPP reduction path now uses gpu::SubgroupReduceOp directly with
cluster_size, which handles cross-lane communication within a wavefront
without requiring explicit data rearrangement through shuffle or LDS.
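
As a mental model, a clustered subgroup reduce computes the following (a hedged HIP-style sketch of the semantics, not of the DPP instruction selection):

```cpp
// Hypothetical sketch of gpu.subgroup_reduce with a cluster_size: the wave is
// partitioned into groups of clusterSize consecutive lanes, and each group is
// reduced independently, entirely in registers.
__device__ float clusterReduceAdd(float v, int clusterSize) {
  for (int offset = clusterSize / 2; offset > 0; offset /= 2)
    v += __shfl_xor(v, offset, clusterSize); // width confines the exchange
  return v; // every lane in a cluster holds that cluster's sum
}
```
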
Key changes:
- Remove shuffleRearrangeForDPP() and all shuffle optimization logic
  (canUseShuffleOptimization, canUseSerialShuffle, XOR butterfly)
- Restrict DPP activation to partial_r > 2, as configurations with
  partial_r = 2 do not benefit from DPP due to insufficient work to
  amortize the instruction overhead; these fall back to LDS-Tree
- Remove the forced scalar accumulation (vectorLen = 1) for the DPP
  threadwise pre-reduction
- Simplify LDS store to be unconditional (no longer skipped by shuffle)

…rch DB for wave size

- Change the canUseDPP condition from >= to == for blockSize vs
  clusterSize * nonReductionDimSizeProduct, preventing potential
  out-of-bounds LDS writes by extra threads when blockSize exceeds
  the exact thread count needed for the DPP layout (see the guard
  sketch after this list).
- Replace hard-coded chipset major version heuristic in
  SubgroupReduceToDPP with rock::lookupArchInfo(chip).waveSize
  for more robust subgroup size derivation.
- Update lowering_blockwise_broadcast_reduce test to use dimensions
  where blockSize == clusterSize * nrDimProd (8 == 2 * 4).
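
Putting the tightened condition together with the earlier eligibility rules, the guard reads roughly as follows (a hypothetical C++ condensation of the checks described in this PR, not the actual code):

```cpp
#include <cstdint>

// Hypothetical condensation of the DPP eligibility checks. Exact equality
// keeps surplus threads (blockSize larger than the exact layout size) from
// performing out-of-bounds LDS writes.
static bool canUseDPP(int64_t blockSize, int64_t clusterSize,
                      int64_t nrDimProduct, int64_t waveSize) {
  bool isPow2 = clusterSize > 0 && (clusterSize & (clusterSize - 1)) == 0;
  return isPow2 && clusterSize <= waveSize &&
         blockSize == clusterSize * nrDimProduct; // was >=, now ==
}
```
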
Introduce a register-only reduction path using v_permlanex16_var_b32
(GFX12+) for blockwise broadcast-reduce when partialR=2 on wave32
architectures. This avoids the initial LDS store + barrier by performing
the 2-way reduction directly in registers before writing to LDS.
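
In spirit, the register-only path does the following (a hedged HIP-style approximation; the lowering emits v_permlanex16_var_b32 itself, all names here are hypothetical, and the 16-lane partner distance is an assumption matching the row-crossing exchange that instruction provides on wave32):

```cpp
// Hypothetical approximation of the register-only 2-way reduction on wave32.
// Assumes the partialR=2 partner sits 16 lanes away, i.e. in the other
// 16-lane row, which is the exchange v_permlanex16_var_b32 performs.
__device__ float reducePartialR2(float v) {
  float partner = __shfl_xor(v, 16, 32); // stand-in for v_permlanex16_var_b32
  return v + partner; // reduced in registers; LDS store + barrier skipped
}
```
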
stefankoncarevic changed the base branch from dpp-refactor-blockwise-reduce to develop on April 27, 2026 at 11:44
