
Introduce migraphx.attention op in MIGraphX Dialect #2316

Draft
umangyadav wants to merge 39 commits into develop from migraphx_attention

Conversation

umangyadav (Member) commented Mar 27, 2026

Motivation

Currently MIGraphX passes a series of decomposed ops for attention. Then, inside TosaToRock, rocMLIR pattern-matches to detect attention and determine which attention variant it is.

This pattern matching is becoming increasingly difficult on the rocMLIR side; it is much better for the MIGraphX graph compiler to do it.

Therefore, this PR introduces a migraphx.attention op that takes a "features" attribute describing which attention variant it is.

Technical Details

  • Adds a migraphx.attention op with semantics similar to the rock.attention op.
  • For host compilation, the migraphx.attention op is decomposed so that it can then be lowered to linalg.
  • For the GPU path, it lowers directly to the rock.attention op.
  • Adds a migraphx.yield op for the preSoftmaxBody region.
  • Adds the missing Elementwise trait to MIGraphX elementwise ops.
  • Adds a utility method to the C API to construct the attention op. (Subject to change.)
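
For readers who want the math spelled out, here is a minimal NumPy sketch of the plain, feature-less semantics the op encodes (illustrative only, not rocMLIR code; per the examples further down, K is passed pre-transposed so the first GEMM is a plain Q @ K):

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_reference(q, k, v, pre_softmax_body=None):
    # q: [..., seq_q, d], k: [..., d, seq_k], v: [..., seq_k, d_v]
    qk = q @ k                           # first GEMM
    if pre_softmax_body is not None:
        qk = pre_softmax_body(qk)        # elementwise preSoftmaxBody region
    return softmax(qk, axis=-1) @ v      # softmax + second GEMM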

Test Plan

  • Adds some initial E2E tests for this.
  • Convert all existing E2E tests into equivalent migraphx.attention form to make sure functionality is preserved with accuracy.
  • Integrate with MIGraphX.

Merge plan:

I plan to break this large PR down into several smaller ones, but I am keeping this large draft open for feedback on the overall structure.

  • Add elementwise trait
  • Add migraphx.yield operator
  • Add migraphx.attention op
  • Lowering to CPU and GPU paths with initial E2E tests
  • Convert all existing E2E tests to use this new migraphx.attention op
  • Add CAPI utilities and tests

Note to MIGraphX folks

Look at the files in mlir/test/fusion/pr-e2e/migraphx-attention/ to see examples of attention kernels.

For example, an attention kernel with sliding window + kvcache + causal masking looks like the following:

module {
  func.func private @mlir_attention(%arg0: !migraphx.shaped<1x2x1x2xf16, 4x2x2x1>,
                                     %arg1: !migraphx.shaped<1x2x2x8xf16, 32x16x8x1>,
                                     %arg2: !migraphx.shaped<1x2x8x2xf16, 32x16x2x1>,
                                     %arg3: !migraphx.shaped<1x2xi32, 2x1>)
                                     -> !migraphx.shaped<1x2x1x2xf16, 4x2x2x1> {
    %0 = migraphx.attention %arg0, %arg1, %arg2
      current_seq_len(%arg3 : !migraphx.shaped<1x2xi32, 2x1>) {
      } features = "kvcache|causal|sliding_window" slidingWindowSize = 4
      : <1x2x1x2xf16, 4x2x2x1>, <1x2x2x8xf16, 32x16x8x1>, <1x2x8x2xf16, 32x16x2x1>
      -> <1x2x1x2xf16, 4x2x2x1>
    return %0 : !migraphx.shaped<1x2x1x2xf16, 4x2x2x1>
  }
}

SplitKV = 2 with LSE and a "preSoftmaxBody":

  func.func private @mlir_attention(%arg0: !migraphx.shaped<1x4x64x64xf16, 16384x4096x64x1>,
                                     %arg1: !migraphx.shaped<1x4x64x128xf16, 32768x8192x128x1>,
                                     %arg2: !migraphx.shaped<1x4x128x64xf16, 32768x8192x64x1>,
                                     %arg3: !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>)
                                     -> (!migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>,
                                         !migraphx.shaped<1x4x2x64xf32, 512x128x64x1>) {
    %0, %1 = migraphx.attention %arg0, %arg1, %arg2
      pre_softmax_inputs(%arg3 : !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>) {
      ^bb0(%qk: !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>,
           %s: !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>):
        %scaled = migraphx.mul %qk, %s
          : <1x4x2x64x64xf16, 32768x8192x4096x64x1>, <1x4x2x64x64xf16, 32768x8192x4096x64x1>
          -> <1x4x2x64x64xf16, 32768x8192x4096x64x1>
        migraphx.yield %scaled : !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>
      } softmax_type = f32 features = splitkv splitKV = 2
      : <1x4x64x64xf16, 16384x4096x64x1>, <1x4x64x128xf16, 32768x8192x128x1>, <1x4x128x64xf16, 32768x8192x64x1>
      -> <1x4x2x64x64xf16, 32768x8192x4096x64x1>, !migraphx.shaped<1x4x2x64xf32, 512x128x64x1>
    return %0, %1 : !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>, !migraphx.shaped<1x4x2x64xf32, 512x128x64x1>
  }

Copilot AI (Contributor) left a comment


Pull request overview

This PR introduces a first-class migraphx.attention op to represent attention variants (via a composable feature bitmask) in the MIGraphX dialect, reducing downstream pattern-matching complexity and enabling clearer lowering paths for host (decompose) vs GPU (lower to rock.attention).

Changes:

  • Add migraphx.attention + migraphx.yield, along with attention feature flags and verifier coverage (valid/invalid tests).
  • Add host-side decomposition of migraphx.attention inside MIGraphXTransform and a new GPU lowering pass migraphx-attention-to-rock.
  • Integrate the new pass into pipelines and add initial E2E + conversion + C API construction tests.

Reviewed changes

Copilot reviewed 38 out of 38 changed files in this pull request and generated 2 comments.

File Description
mlir/test/rocmlir-driver/pipelines.mlir Updates expected high-level pipelines to include migraphx-attention-to-rock.
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-basic.mlir Adds E2E attention test case (basic).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-softmax-f32.mlir Adds E2E attention test case (softmaxType f32).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-lse.mlir Adds E2E attention test case (LSE output).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-scale.mlir Adds E2E attention test case (pre-softmax scale/bias region via migraphx.yield).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-causal.mlir Adds E2E attention test case (causal).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-causal-scale.mlir Adds E2E attention test case (causal + pre-softmax body).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache.mlir Adds E2E attention test case (kvcache).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal.mlir Adds E2E attention test case (kvcache + causal).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal-prefix.mlir Adds E2E attention test case (kvcache + causal + prefix offset).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal-sliding-window.mlir Adds E2E attention test case (kvcache + causal + sliding window).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-scale.mlir Adds E2E attention test case (kvcache + pre-softmax body).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-splitkv.mlir Adds E2E attention test case (splitKV).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-gqa.mlir Adds E2E attention test case (GQA shapes).
mlir/test/Dialect/MIGraphX/ops.mlir Adds parsing/printing coverage for new ops/attrs and attention variants.
mlir/test/Dialect/MIGraphX/invalid.mlir Adds verifier negative tests for attention feature/shape/operand constraints.
mlir/test/Conversion/MIGraphXToTosa/migraphx-to-tosa-preserves-rock-attention.mlir Ensures MIGraphXToTosa doesn’t rewrite rock.attention regions produced earlier.
mlir/test/Conversion/MIGraphXAttentionToRock/attention-to-rock.mlir Adds conversion tests for --migraphx-attention-to-rock.
mlir/test/Conversion/MIGraphXAttentionDecompose/attention-decompose.mlir Adds host decomposition tests for attention variants and feature combinations.
mlir/test/CAPI/mixr_attention.c Adds C API test that constructs migraphx.attention ops.
mlir/test/CAPI/CMakeLists.txt Builds/links new mlir-mixr-attention-test.
mlir/lib/Dialect/MIGraphX/Transforms/MIGraphXTransform.cpp Implements host-side migraphx.attention decomposition (non-kernel funcs).
mlir/lib/Dialect/MIGraphX/Pipeline/Pipeline.cpp Inserts MIGraphXAttentionToRock into the high-level pipeline before TOSA/Linalg lowering.
mlir/lib/Dialect/MIGraphX/Pipeline/CMakeLists.txt Links in new MLIRMIGraphXAttentionToRock library.
mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp Adds AttentionOp verifier + feature dependency checks.
mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosaPass.cpp Marks Rock dialect / rock.attention recursively legal to preserve nested ops.
mlir/lib/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.cpp New lowering pass from migraphx.attention to rock.attention for kernel funcs.
mlir/lib/Conversion/MIGraphXAttentionToRock/CMakeLists.txt Adds conversion library target for MIGraphXAttentionToRock.
mlir/lib/Conversion/CMakeLists.txt Adds new conversion subdirectory.
mlir/lib/CAPI/Dialect/MIGraphX.cpp Adds C API builder helper rocmlirMIGraphXAttentionCreate.
mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphXTypes.td Adds AttentionFeatures bit-flag enum attr.
mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td Adds migraphx.yield, migraphx.attention, and applies Elementwise trait to elementwise ops.
mlir/include/mlir/Conversion/RocMLIRPasses.td Declares migraphx-attention-to-rock pass.
mlir/include/mlir/Conversion/RocMLIRPasses.h Exposes new conversion header.
mlir/include/mlir/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.h Adds public pass declaration header.
mlir/include/mlir-c/Dialect/MIGraphX.h Bumps C API version and declares attention builder + feature bit macros.


umangyadav self-assigned this Mar 27, 2026
dhernandez0 (Contributor) commented Mar 31, 2026

Some general comments/questions:

  • Why do we need rocmlirMIGraphXAttentionCreate? Can't they just send IR the same way they do for conv/gemm?
  • Why limit preSoftmaxBody to elementwise only? I think they might currently have reshapes inside the pre-softmax body. This might be too strict.
  • IMO the attention features enum adds complexity (what happens if they send features without kv-cache but still pass currentSeqLen?). Is there a reason we need it? We can detect the features from which optional params they send (currentSeqLen, prefixOffset, splitKV, slidingWindowSize). The only one that doesn't have an associated param is causal; we can solve this by just having a boolean causal param in rock.attention.
  • AttentionDecompose is a lot of code just for CPU lowering, which is not really relevant for us. I wonder if there's an easier way to do this, such as using some existing dialect (torch/aten), converting migraphx.attention -> aten.attention?, and using their lowering to linalg?

umangyadav and others added 25 commits April 30, 2026 18:22
Per the workspace-hygiene rule, plan files and scratch notes should live
in a git-excluded directory and never be committed. Add the conventional
plans/, scratch/, and notes/ paths to .gitignore so per-session working
documents (e.g. review reports, TODO outlines) stay out of the tree.

Made-with: Cursor
…ffset

Previously the verifier accepted currentSeqLen / prefixOffset operands
of "rank 1 or 2", but MIGraphXAttentionToRock collapses 4-D Q/K/V to
Rock's [B*H, seq, dim] layout without expanding rank-1 [batch] operands
across heads. The result was either a confusing rock.attention verifier
error (currentSeqLen) or a silent OOB index/load at the gridwise lowering
(prefixOffset, which has no shape check on the Rock side).

Tighten AttentionOp::verify so the operand shape must equal Q's full
leading dims exactly (e.g. [batch] for 3-D Q, [batch, numHeads] for 4-D
Q). Producers with a per-batch sequence length must broadcast across
heads explicitly via migraphx.multibroadcast before constructing the
op; the flattened [B*H] layout remains an internal detail of the
kernel-side rock.attention and is materialised by the lowering.
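
To make the new rule concrete, a tiny Python sketch of the expected shape (hypothetical helper name, purely illustrative):

def expected_seq_len_shape(q_shape):
    # currentSeqLen / prefixOffset must match Q's full leading dims exactly:
    # [batch] for 3-D Q, [batch, numHeads] for 4-D Q.
    return list(q_shape[:-2])

assert expected_seq_len_shape([1, 2, 1, 2]) == [1, 2]   # 4-D Q
assert expected_seq_len_shape([8, 64, 64]) == [8]       # 3-D Q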

Update an attention-decompose test that was using a rank-1 prefixOffset
with 4-D Q (which the new rule rejects), add positive coverage for the
canonical 4-D + rank-2 [B, H] form in ops.mlir, and add invalid-IR
tests for both currentSeqLen and prefixOffset rank-1 + multi-head Q.

Made-with: Cursor
The verifier already enforced "slidingWindowSize attribute set ⇒
sliding_window feature set" but not the inverse. With features =
"kvcache|sliding_window" and no attribute the host decompose path hits
an unreachable assertion (release builds null-deref) and the GPU
lowering silently degrades to plain KV-cache masking with no diagnostic
because MIGraphXAttentionToRock only forwards a non-null attribute to
rock.attention. Either way the user explicitly asked for sliding-window
attention and got the wrong answer or a crash.

Add a verifyAttrRequiredByFeature helper symmetric with the existing
verifyOperandRequiredByFeature and use it to reject
sliding_window-without-slidingWindowSize at the migraphx.attention op
level. Cover the new path with an invalid-IR test and tidy up a
leftover comment block from a deleted helper while we're here.

Made-with: Cursor
The existing i8 attention code path goes through migraphx.quant_dot +
migraphx.dequantizelinear + migraphx.softmax + migraphx.dot at the
graph level, fused into a rock.attention by TosaToRock's pattern
matcher. Direct construction of migraphx.attention with i8 Q/K was
type-checked as legal (AttentionQKTypes includes I8/SI8) but had no
working lowering: the MIGraphXAttentionToRock body builder forced every
linalg.generic block arg to the output element type and only emitted
arith.*f, and the host AttentionDecompose path emitted migraphx.dot
which doesn't accept i8.

Make the direct path actually work end-to-end:

- AttentionOp::verify accepts migraphx.dequantizelinear in
  preSoftmaxBody alongside Elementwise-trait ops via a small allowlist,
  and requires softmaxType to be set explicitly when Q is non-float so
  the body has a known float target type to dequantize to.
- MIGraphXAttentionToRock now derives each linalg.generic block-arg
  type from its own input (allowing i32 QK from an i8 first GEMM to
  flow into the body until a dequantize op upcasts it) and lowers
  migraphx.dequantizelinear to (cast<float>(input) - cast<float>(bias))
  * scale (see the sketch after this list) using the upstream mlir::convertScalarToDtype helper that
  MIGraphXToLinalg::castTensor already uses; signedness is read from
  the original MIGraphX-side operand type because MIXRShapedType::asTensor
  drops it.
- AttentionDecompose (host CPU path used by --verifier clone) emits
  migraphx.quant_dot with an i32 QK result for integer Q/K, mirroring
  what the existing TosaToRock-driven path produces. Also fix a
  pre-existing bug where the softmaxType-vs-elementType convert check
  compared against Q's original type instead of the body-yielded
  current type, which meant a redundant convert was being added when
  the body had already changed precision.
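
As a one-line sketch, the dequantize scalar rule from the second bullet above (hypothetical Python, not the actual arith-op emission):

def dequantize_linear(x_int, scale, zero_point=0):
    # (cast<float>(input) - cast<float>(zero_point)) * scale
    return (float(x_int) - float(zero_point)) * scale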

Tests:

- Update the previously-empty migraphx_attention_i8_qk in ops.mlir to
  use a meaningful body with migraphx.dequantizelinear and
  softmax_type = f32.
- Add invalid.mlir coverage for the new "softmaxType required when Q
  is integer" rule.
- Add two attention-to-rock.mlir lit tests covering the new
  arith.sitofp / arith.subf / arith.mulf scalar lowering with and
  without a dequant zero-point.
- Add five pr-e2e/migraphx-attention/mixr-attention-first-gemm-i8*.mlir
  E2E tests that mirror the existing pr-e2e/attention/* originals (kept
  alongside) but expressed as a single migraphx.attention with the
  dequant in preSoftmaxBody. They round-trip cleanly through
  --verifier clone, confirming the host quant_dot path produces
  matching numerics.

Made-with: Cursor
The verifier accepted any Q whose leading dims were divisible by K's
("queries=qd is divisible by keys=kd"), but for rank-3 Q the (batch,
numHeads) split is unrecoverable from the shape alone: <12x4x8> Q
against <4x8x16> K could mean B=1/H_q=12/H_kv=4 or B=2/H_q=6/H_kv=2 or
B=4/H_q=3/H_kv=1, each producing a different (numHeadsQ, numHeadsKV,
batch) for rock.attention. Both downstream paths fall over on this
input today: the host AttentionDecompose emits an invalid migraphx.dot
(batch dim mismatch) and the GPU MIGraphXAttentionToRock fails Rock's
verifier (numHeadsQ defaults to 1 because getNumHeads only knew the
4D shape).

Tighten AttentionOp::verify so GQA (any leading dim of Q != K's) is
only legal when Q is rank >= 4, where the heads axis is unambiguously
dim 1 by convention. Equal-heads 3D Q (numHeadsQ == numHeadsKV) is
still accepted because rock.attention is symmetric in that case
between "8 heads x 1 batch" and "1 head x 8 batches".

The previous reshape-trace fallback in MIGraphXAttentionToRock::
getNumHeads (which looked at a defining migraphx.reshape to recover
the 4D heads count from a 3D Q) is now dead code: with the verifier
rule above, a 3D Q only reaches this lowering when numHeadsQ ==
numHeadsKV, where the choice between "real heads" and "1 head, big
batch" is numerically irrelevant. Drop the trace and the
attention_gqa_3d unit test that exercised it; add a negative
invalid.mlir test confirming the new rule's diagnostic.

Add two new pr-e2e/migraphx-attention/ E2E tests
(mixr-attention-gqa-bias, mixr-attention-gqa-scale) mirroring the
existing pr-e2e/attention/* GQA tests but expressed as a single
migraphx.attention with a preSoftmax body. Originals are kept; this
broadens coverage of the direct migraphx.attention path through the
GQA lowering chain.

Made-with: Cursor
The verifier was checking the preSoftmaxBody block argument *count* but
not their *types*: a body declaring block-arg 0 as <2x4x99xf16> against
a real QK shape of <2x4x16xf16> would slip through and fail far
downstream during lowering. The splitKV path also only checked the
rank of preSoftmaxElemWiseInputs, not their dim sizes. With the M5 i8
support landed earlier in this branch, block-arg 0's element type now
also depends on whether Q is integer (the body sees the i32 quant_dot
output and is expected to dequantize it), which the old verifier did
not enforce either.

Consolidate the body checks into a single block that:

- Computes the expected QK shape once (with splitKV inflation when
  the splitkv feature is set), and the expected QK element type (Q's
  type for float Q, i32 for integer Q).
- Walks the body once with block.without_terminator(), confirming each
  op is in the existing migraphx.elementwise / dequantizelinear
  allowlist.
- For an empty body: requires a bare yield and zero block arguments.
- For a populated body, in order: arg count matches, block-arg 0 has
  exactly the QK shape and element type, each block-arg i+1 matches
  preSoftmaxElemWiseInputs[i] and that input's shape matches the QK
  shape, and the yield's value shape matches the QK shape (element
  type is free so the body can dequantize / convert before softmax).

This is a verifier-only change. All previously-valid IR still verifies
(confirmed by the targeted attention sweep: 772 tests, 743 pass, 0
unexpected failures).

Update one existing invalid.mlir test whose old splitKV-rank-only
diagnostic is now subsumed by the more specific block-arg-0 check, and
add four new negative tests covering wrong block-arg shape, wrong
block-arg element type for i8 Q (signedness), block-arg i+1 vs input
type mismatch, and input shape vs QK shape mismatch.

Made-with: Cursor
The host AttentionDecompose path defaulted softmaxElemType to the value
entering softmax (Q's element type for an empty body, the body's yielded
type otherwise) while rock::gridwise_attention_accel defaulted to V's
element type. When Q.elemType != V.elemType and softmaxType was unset,
the two paths would silently softmax in different precisions and
--verifier clone would diverge. The previous M5-era check guarded the
narrower "Q is integer-typed" subset of this footgun but missed the
all-float-but-mismatched-precisions case (e.g. bf16 Q/K with f32 V).

Replace the integer-Q-specific check with a more general one in
AttentionOp::verify: when the value entering softmax doesn't already
have V's element type, the producer must set softmaxType explicitly so
the lowering can insert convert ops on either side of softmax. The new
rule is strictly stronger -- it still rejects all the cases the old
rule rejected (integer Q with empty body still fails because i32 != V),
correctly accepts the case the old rule wrongly rejected (integer Q
with a body that dequantizes to V's element type), and additionally
catches the float-but-mismatched case the old rule missed. The check
moved to after the body validation so it can read the yielded element
type.

Align the host AttentionDecompose default with rock's gridwise default
by reading op.getSoftmaxType().value_or(vType.getElementType()),
mirroring rock::gridwise_attention_accel:2147 line for line. With the
new verifier rule above this is provably safe -- either softmaxType is
explicitly set (and value_or returns it) or the value entering softmax
already matches V (so the default lands on the same value the old
"qkCurrentElemType" default would have picked).

Add a TODO at the host's second GEMM noting that migraphx.dot
accumulates in the operand element type rather than promoting to the
softmax type the way the GPU mfma path does. For long sequences this
can produce slightly less accurate CPU reference values than the GPU;
widening the dot's accumulator (or splitting into f32 partial sums +
downcast) would close the gap.

Update the M5-era invalid-IR test whose old diagnostic message changed
and add two new negative tests:
- Float Q (bf16) and V (f32) with no body and no softmaxType -- the
  C2.4 footgun, now rejected at the verifier.
- Body that yields a different element type than V (f32 vs f16) with
  no softmaxType -- same rule, populated-body branch.

Add a positive E2E test mirroring the new shape: bf16 Q/K with f32 V
and an explicit softmax_type = f32, batch=1 heads=4. Both the host
CPU reference and the GPU kernel agree within bf16 precision (verified
via --verifier clone).

Made-with: Cursor
The pr-e2e/attention/ directory has a family of bare matmul+softmax+matmul
tests with non-power-of-2 padded shapes (head_dim=3, seq_len=7) and
various pre-softmax bodies (none, bias, scale, scale+bias, scale+exp+bias,
multi-op tree, full-shape constant scale). These tests cover the
TosaToRock attention pattern matcher; they did not exercise the
migraphx.attention direct path.

Add nine pr-e2e/migraphx-attention/ counterparts that express the same
shapes and numerics as a single migraphx.attention with a preSoftmax
body. Originals are kept; the new tests broaden coverage of the direct
lowering path through MIGraphXAttentionToRock and the host
AttentionDecompose:

- mixr-attention-square-3d (mirror of mixr-attention.mlir, square 3D
  64x64 shape, no body)
- mixr-attention-padded (3D 7x3 shape, no body)
- mixr-attention-padded-bias (1-input add body)
- mixr-attention-padded-scale (1-input mul body)
- mixr-attention-padded-scale-and-bias (2-input mul+add body)
- mixr-attention-padded-const-scale (full-shape migraphx.literal scale,
  shape rewritten from the original 1xf32+multibroadcast to a full-shape
  literal so the materialised scale matches the QK shape)
- mixr-attention-padded-scale-cross (cross-attention seqQ=128, seqK=27,
  perf_config forwarded onto migraphx.attention)
- mixr-attention-padded-scale-bias-exp (3-op body: mul, exp, add)
- mixr-attention-padded-complex-tree-elemwise (4-op tree body combining
  auxiliary inputs at leaf level then scaling/biasing the QK output)

The padded-scale-bias-exp test exercises migraphx.exp in the body.
That op already had the Elementwise trait (so the verifier accepted
it), but the MIGraphXAttentionToRock scalar lowering only handled
mul/add/sub/neg/dequantizelinear. Add a single math.exp scalar lowering
plus the matching MLIRMathDialect link library and dependentDialects
entry; all other ops in the new tests already had scalar lowerings.

All nine new tests pass end-to-end on GPU and verify against the host
CPU reference (--verifier clone). Targeted attention sweep is
unchanged otherwise: 782 tests, 753 pass, 4 expected fails (pre-existing
GPU-required), 25 unsupported, 0 unexpected failures.

Made-with: Cursor
Two follow-ups from the consolidated review:

R.4: the host applySlidingWindowMask computes
   lowerBound = currentSeqLen + (-windowSize)
without clamping to zero. It produced correct numerics today only
because tosa.greater (and migraphx.greater) treat their signed integer
operands as signed, so a negative lower bound compared against a
non-negative column index returns false and no spurious mask is
applied. The GPU side (rock::gridwise_attention_accel:1844) explicitly
emits arith.maxsi against zero. Mirror that here with a where(0 >
lowerBound, 0, lowerBound) clamp so the host IR no longer depends on
the signed-comparison convention and the two paths use the same
documented invariant.

C/N: MIGraphX.td documented the kvcache mask predicate as
   "key_index >= currentSeqLen"
but both host (applyKVCacheMask) and GPU
(GridwiseGemmToBlockwise.cpp:1318 ugt mIndex, currentSeqLen) actually
use strict greater. Update the doc text to match the implementation
and clarify that currentSeqLen is the index of the last valid key
position (range [0, currentSeqLen] inclusive), matching PyTorch SDPA
and FlashAttention. Drop a stale "TODO: verify that negative prefixes
work" comment in the same block whose answer the surrounding
prefix_offset description already gives ("the offset can be positive
or negative").

Strip the trailing blank line at the end of invalid.mlir flagged by
git diff --check.

Made-with: Cursor
The MIGraphXAttentionToRock body lowering only mapped 5 migraphx ops
to scalar arith/math (mul/add/sub/neg/exp/dequantizelinear), but
AttentionOp::verify accepted any op carrying the Elementwise trait
plus migraphx.dequantizelinear. Anything outside the lowering's small
set (e.g. migraphx.div in a body) passed the verifier and then failed
deep in the conversion pass with an opaque "unsupported migraphx op"
diagnostic far from the user's migraphx.attention. Also, the explicit
five-op subset constrained what producers could put in the body
without good reason.

Extend lowerMIGraphXElementwiseToScalar to cover the full set of
scalar-lowerable migraphx elementwise ops, mirroring
MIGraphXToLinalg's ElementwiseConverter / GenericElementwiseOpConverter
coverage so host CPU and GPU body lowerings stay numerically aligned.
The new mapping is:

  Binary float : add (addf), sub (subf), mul (mulf), div (divf),
                 pow (math.powf)
  Unary float  : neg (negf), abs (math.absf), ceil (math.ceil),
                 floor (math.floor), exp (math.exp), log (math.log),
                 sqrt (math.sqrt), tanh (math.tanh), erf (math.erf)
  Composed     : recip (1/x), relu (max(0,x)), sigmoid (1/(1+exp(-x)))
  Selection    : where (i8 cond -> arith.select via i1 cast)
  Cast         : convert (mlir::convertScalarToDtype, signedness-aware)
  Quant        : dequantizelinear (unchanged)

Tighten AttentionOp::verify to use an explicit allowlist matching the
new lowering set exactly. The previous Elementwise-trait test let
greater/equal/clip/rsqrt slip through unsupported; the explicit
allowlist makes the verifier-vs-lowering contract crisp and adding a
new body op is a one-line change in two coupled places. Update the
diagnostic to point at lowerMIGraphXElementwiseToScalar so the next
maintainer knows where to look.

Tests:

- Update two pre-existing invalid.mlir tests whose assertions
  referenced the old "must only contain elementwise migraphx ops"
  wording; the cases (migraphx.dot and migraphx.reduce_sum in body)
  still reject, just under the new message.
- Add attention_extended_body_ops to attention-to-rock.mlir exercising
  div/pow/neg/abs/exp/log/sqrt/tanh/erf/recip/relu/sigmoid/where/
  convert in a single body, with FileCheck patterns confirming each
  maps to its expected arith/math scalar op.

Verification: targeted attention sweep is unchanged at 782 tests,
753 pass, 4 expected fails (pre-existing GPU-required), 25
unsupported, 0 unexpected failures.

Made-with: Cursor
Two contracts were duplicated across the migraphx.attention lowering
chain: the QK output element type rule (Q.elemType for float Q, i32 for
integer Q from a quantized first GEMM), and the closed set of body ops
that MIGraphXAttentionToRock::lowerMIGraphXElementwiseToScalar can
scalar-lower. The element-type rule was open-coded in three places
(AttentionOp::verify, MIGraphXTransform's host AttentionDecompose, and
rocmlir-gen's attention generator) and the body allowlist was both a
local lambda in the verifier and an implicit dispatch table in the
lowering, with only a comment cross-referencing them.

Add a header-only mlir/include/mlir/Dialect/MIGraphX/IR/AttentionUtils.h
exposing two small inline helpers:

- computeAttentionQKElemType(qElemType, ctx): encodes the QK-output
  element type rule.
- isAllowedInPreSoftmaxBody(op): the closed allowlist of body ops with
  a matching scalar lowering.

Refactor the three element-type call sites (verifier / host decompose /
rocmlir-gen) and the verifier allowlist to consume the shared helpers,
removing the local lambdas and ternaries. The header's docstring
calls out the lock-step coupling between isAllowedInPreSoftmaxBody and
the lowering's dispatch table so future maintainers know to update
both places when adding a new body op.

Items NOT extracted (deliberately):
- expectedQKShape: the verifier and host decompose work from different
  inputs (pre- vs post-splitKV-reshape types) so unifying would obscure
  the splitKV reshape that's already happening.
- getNumHeads vs TosaToRock's getNumHeadsGQA: same name, different
  operations (read dim 1 of a 4D migraphx type vs walk back through
  tensor.collapse_shape on the tosa side).
- Contiguous-strides construction: single use, no duplication.
- isMixrUnsignedInt: single use; the inline form is clearer at the
  call site.

Verification: targeted attention sweep is unchanged at 782 tests, 753
pass, 4 expected fails (pre-existing GPU-required), 25 unsupported,
0 unexpected failures. Pure NFC -- no behavior change, just shifts
the rule definitions to one place.

Made-with: Cursor
Three refactors that don't change behaviour:

- Validate splitKV up-front in AttentionOp::verify so result/LSE/QK shape
  construction can use a single validated effectiveSplitKV value, instead
  of computing it twice with subtly different gates (>1 vs. just present).
- Drop the unreachable `splitKV <= 0` check: the orphan-attr path now
  rejects splitKV-without-feature, and the splitkv-feature path requires
  splitKV > 1, so a zero/negative value is always caught earlier with a
  more specific diagnostic.
- Replace three open-coded contiguous-stride loops in MIGraphXTransform's
  AttentionDecompose with the existing makeContiguousType helper.

Co-authored-by: Cursor <cursoragent@cursor.com>
…ndow clamp

Two small follow-ups for the migraphx.attention preSoftmaxBody and
sliding-window mask lowering:

- Assert in AttentionToRockPattern's body builder that any op the
  verifier admits via isAllowedInPreSoftmaxBody is also lowered by
  lowerMIGraphXElementwiseToScalar. The two lists were already
  documented as a lock-step pair, but nothing was catching divergence
  beyond a structured runtime error. The assertion fires in debug
  builds the moment the dispatcher is missing a case the verifier
  approves; release builds continue to surface the same
  "unsupported migraphx op in preSoftmaxBody" error.
- Add an E2E test pinning currentSeqLen to [0, 2] with
  slidingWindowSize = 4 so the unclamped lowerBound formula
  (currentSeqLen - windowSize) is always negative. Both host and GPU
  apply max(0, .) and produce the same numerics; --verifier=clone
  guards against a future regression where one side stops clamping.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two related verifier additions that close gaps the previous review
flagged:

- Forbid features = "splitkv | sliding_window" outright. splitkv
  reshapes the body to operate on per-chunk [seqK / splitKV] columns
  while sliding_window's lower bound lives in absolute K-position
  space, and the gridwise lowering's mask uses that absolute index
  too. Nothing in the current pipeline reconciles those two views
  and no test exercises the combination, so reject it explicitly
  rather than have host and GPU silently disagree. Adds a new
  verifyFeatureMutualExclusion helper for the rejection.
- Spell out that 'sliding_window' requires currentSeqLen, on top of
  its existing dependency on 'kvcache'. Today the kvcache requirement
  catches the missing operand transitively; making the rule explicit
  protects the sliding-window path if the kvcache linkage ever
  changes.

Add an invalid test for the splitkv + sliding_window rejection.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two small cleanups in the migraphx.attention lowering and host
decomposition:

- Drop convertMIXRToTensor in MIGraphXAttentionToRock. It was a
  one-line wrapper around migraphx::AsLogicalShapeOp::create used in
  exactly one call site; the rest of the file already calls the op
  builder directly, so inline the only caller.
- Tighten the host LSE reshape in AttentionDecompose. Pull the target
  shape directly from the op's lse result type and name the current
  element type, so the data flow is clearer than the earlier "build
  a vector from the verifier-output shape and walk it twice" pattern.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two correctness bugs that surface when splitKV interacts with the
preSoftmax body or attention masks:

- AttentionToRockPattern unconditionally set
  preSoftmaxHasSplitKVTransforms = false. The verifier requires
  preSoftmax inputs to live in split space ([B*H*splitKV, SeqQ,
  SeqK/splitKV]), but without this attribute the gridwise lowering
  leaves the GEMM0 output in the un-split shape and the linalg.generic
  body iterates the two operands in disagreeing index spaces. Set the
  attr to true when splitKV > 1 and the body has inputs.
- Host AttentionDecompose's mask helpers (causal / kvcache / sliding
  window) used qkShape[rank-1] as their column iota, which is the
  per-chunk seqK / splitKV when the body operates in split space.
  Masks compare against the original currentSeqLen / row index, so
  per-chunk column indices mask the wrong keys for non-zero split
  chunks. Plumb splitKV through createBroadcastColIndices and emit
  splitIdx * seqKPerSplit + localCol when splitKV > 1, matching the
  GPU's absolute mIndex logic.
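
The global column index from the second bullet, as a Python sketch (hypothetical names; the real change emits this as IR):

def global_col_index(split_idx, local_col, seq_k, split_kv):
    seq_k_per_split = seq_k // split_kv
    return split_idx * seq_k_per_split + local_col

assert global_col_index(split_idx=1, local_col=3, seq_k=128, split_kv=2) == 67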

Add an IR-level FileCheck test exercising the global-index path for
splitkv + kvcache, and document the still-loose splitkv-scale e2e
threshold (the host's per-chunk softmax decomposition diverges from
GPU's combined-via-LSE recombination irrespective of these fixes).

Co-authored-by: Cursor <cursoragent@cursor.com>
Five verifier rules that previously slipped past the migraphx-side
verifier and surfaced as opaque downstream errors (or, worse, invalid
IR like arith.addf-on-i32):

- Reject integer arithmetic in the preSoftmax body (G3). Body block
  arg 0 is i32 for integer-Q attention (output of quant_dot), but the
  scalar lowering only emits arith.{add,mul,sub,div,...}f. Reject any
  body op other than dequantizelinear / convert / where whose operands
  are integer; the body must dequantize or convert before arithmetic.
- Restrict softmaxType to {f16, bf16, f32} (G4). f64 and other exotic
  floats verified at the migraphx layer but tripped a less helpful
  rock-internal verifier later.
- Constrain attention result to AttentionVTypes and lse to
  {f16, bf16, f32} (G5). The TableGen had AnyMIXRShaped on both, so
  e.g. an i32 attention result verified happily.
- Require Q rank exactly 4 when GQA is active (G6). Rank 3 was
  already rejected; rank 5+ now is too because neither the host
  broadcastForGQA nor the GPU getNumHeads detector handle higher ranks
  - they assumed rank 4 and produced wrong index spaces or fell back
  to numHeads = 1.
- Require Q / K / V rank >= 3 (G7). Rank 2 attention slipped past the
  migraphx verifier and failed to legalise through rock-gridwise with
  an opaque "failed to legalize rock.attention" error. Rock requires
  a leading batch dim, so reject rank 2 here with a clear diagnostic.

The previous commit already made splitkv + sliding_window work
correctly (the global-col-index fix covers both kvcache and
sliding_window masks under splitKV), so drop the now-unnecessary
splitkv vs sliding_window mutual exclusion and its helper.

Add invalid tests for each new rule.

Co-authored-by: Cursor <cursoragent@cursor.com>
The earlier comment blamed AttentionDecompose's per-chunk softmax for
the host-vs-GPU divergence on this test. After comparing a tiny
hand-decomposed splitkv attention (the same pattern develop's
pr-e2e/attention/mixr-attention-flash-decoding.mlir uses) against
this branch's migraphx.attention path, both produce the same
per-element divergence and the GPU output's second split chunk is
partially uninitialised (NaN) on the smaller shape. The bug is in
rock::gridwise_attention_accel rather than in this branch.

Update the in-test comment to point at the right place so a future
fix targets the right pass; thresholds stay loose until the rock
kernel writes both partials cleanly.

Co-authored-by: Cursor <cursoragent@cursor.com>
User pointed out that splitkv testing needs gridSize * splitKV
workgroups. Confirmed via the rock-side formula
computeGridSizeAttentionGemmElmtGemm (gridSize = (gemm0N / NPerBlock)
* gemm0G * splitKV) - the launch is correct, both per-workgroup
coords are unique (g, n_block, split_block) via makeGxNGridLayout,
and the rock pipeline does launch one WG per chunk. The bug is
elsewhere in gridwise_attention_accel: chunk 1 keeps stale memory
even though its workgroup runs. Update the comment to reflect this
so a future rock-side fix targets the right code path.
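
The quoted formula with made-up example numbers, just to show splitKV multiplying the workgroup count (values are illustrative only):

def attention_grid_size(gemm0_n, n_per_block, gemm0_g, split_kv):
    return (gemm0_n // n_per_block) * gemm0_g * split_kv

assert attention_grid_size(gemm0_n=128, n_per_block=64, gemm0_g=4, split_kv=2) == 16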

Co-authored-by: Cursor <cursoragent@cursor.com>
User asked whether splitkv needs an output prefill. Confirmed via
deep dive: prefill alone wouldn't fix this. The actual bug is in
rock::gridwise_attention_accel.postProcessFirstGemm
(GridwiseGemmToBlockwise.cpp:1636) - the ThreadwiseReadIntoOp that
reads preSoftmaxElemWiseInputs gets extra indices
{g_block, m_block, n_block, tid}, missing split_block. The body
input is sized [B*H*splitKV, ...] but only the (g, n) coords of the
gemm0 grid are threaded into the read address, so both chunks of
every (g, n) pair read the same body-input slot. Chunk 1 then
computes with the wrong inputs and the output write
(also missing split_block in gridCoordsGemm1) lands on chunk 1's
slot of the output - which holds whatever uninitialised memory
came from the harness alloc.

Confirmed:
- grid is sized correctly (gridSize * splitKV) and bid→(g, n, split)
  mapping is unique.
- splitkv without body passes 0.0005 threshold.
- splitkv with a body that has zero extra inputs (just a convert)
  passes 0.0005 threshold even at our shape.
- splitkv + body + extra inputs reproduces the divergence identically
  on a hand-written develop-style decomposition.

Update the in-test comment to reflect this finding so the rock-side
fix targets the right code path.

Co-authored-by: Cursor <cursoragent@cursor.com>
`createSplitKVTransformsForGemm0Out` returns a stack pointing from
body-input shape `[B*H*splitKV, SeqQ, SeqK_chunk]` (top) down to
gemm0-buffer shape `[B*H, SeqQ, SeqK]` (bottom). The previous
prepend-under-`linalgGridSubTileMaps` left the composed chain ending in
gemm0-buffer space, so when `postProcessFirstGemm` then composed
`linalgToOtherInputMaps` for each `preSoftmaxElemWiseInput` the bottom
no longer matched the user input memref. The resulting
`rock.threadwise_read_into` produced wrong addresses for the extra body
inputs whenever splitKV > 1.

The bug was silent for bodies with no extra inputs (e.g. develop's
flash-decoding test), because the composed chain is only consumed in
the `for (genOpInput)` loop in `postProcessFirstGemm`. With an extra
input (e.g. our splitkv-scale test) chunk 1 produced large numerical
divergence and unreliable LSE.

Invert `splitKVTransforms` before prepending so the composed chain
ends in body-input space, which is what
`postProcessFirstGemm` expects for each `preSoftmaxElemWiseInput`.

Also restore tight `RMS_threshold = 0.002 / relDiff_threshold = 0.0005`
on `mixr-attention-splitkv-scale.mlir` (was loosened earlier as a
workaround) and drop the long workaround comment.

Co-authored-by: Cursor <cursoragent@cursor.com>
Locks in the transform-stack direction the previous commit fixed for
`postProcessFirstGemm`. The test feeds `rock-gridwise-gemm-to-blockwise`
a `rock.gridwise_attention_accel` op with `splitKV > 1` AND a body that
consumes an extra `preSoftmaxElemWiseInput`, then checks that the
resulting `rock.threadwise_read_into` for the user input ends in two
leaf transforms going gemm0-shape -> body-shape:

  Merge{splitKV, seqK_chunk} ["seqK"] -> ["splitKV", "seqK_chunk"]
  Unmerge{batch, splitKV}    ["batch", "splitKV"] -> ["batch"]

These are the inverse of `createSplitKVTransformsForGemm0Out`'s output
and only appear after the inversion in the previous commit. Reverting
the inversion flips both factor orderings and the test fails.

This was the gap that allowed the underlying bug to land: the existing
splitKV tests in `gridwise_attention_accel_lowering.mlir` use empty
`preSoftmaxOps`, and the existing body tests use `splitKV = 1`, so no
test exercised the buggy code path before.

Co-authored-by: Cursor <cursoragent@cursor.com>
Three small simplifications surfaced in the post-fix branch review:

1. `applySlidingWindowMask`: drop the
   `static_cast<uint64_t>(-windowSize)` + computed `signedSemantics`
   dance. `currentSeqLen`'s element type is restricted to `[I32, SI32]`
   by the op definition, so a plain
   `APInt(width, -windowSize, /*isSigned=*/true)` is always correct
   and reads more directly.

2. `AttentionToRockPattern::matchAndRewrite`: replace the
   `&& !srcRegion.empty()` defensive check with an assertion.
   `SingleBlockImplicitTerminator` on `migraphx.attention` guarantees
   the body has a block, so the runtime check was always true.

3. `AttentionToRockPattern::matchAndRewrite`: extract a
   `getCollapseToLastDimReassoc(rank)` helper alongside the existing
   `getLeadingDimReassoc`, and use it for the LSE expand-shape (rank
   -> {{0..rank-2}, {rank-1}}). Replaces an inline reassociation
   builder at the LSE expand site.

No behavior change. All attention lit/E2E tests pass; full
check-rocmlir is green (1370/1370).

Co-authored-by: Cursor <cursoragent@cursor.com>
…fier and decompose

Verifier (`AttentionOp::verify`):
- Factor `qBatch + (splitKV if > 1) + trailing` shape construction into
  a `makeAttnShape` helper, used at 3 sites (result, LSE, body QK shape).
- Factor the `size != size || !std::equal(...)` shape mismatch check +
  diagnostic format into a `checkAttnShape` helper, used at 2 sites
  (result, LSE).
- Hoist `seqK = kShape[kRank - 1]` once at the top of `verify()` and
  reuse for the splitKV-divisibility check, the QK shape construction,
  and the sliding-window max-seq-len bound (was duplicated 3x).

Net: ~30 lines shorter, consistent diagnostic format, and the
splitKV-inflation rule now lives in one place that the next reader
can scan in 4 lines instead of 4 inlined blocks.

Host decompose (`MIGraphXTransform.cpp`):
- Replace the manual permutation loops in `splitKVReshapeK` with the
  upstream `applyPermutation` helper from
  `mlir/Dialect/Utils/IndexingUtils.h`. Drops 2 explicit loops and an
  intermediate `kSplitStrides` ArrayRef.
- Add a small `createBroadcastIntScalar` helper that wraps the
  APInt + DenseElementsAttr + createBroadcastScalar boilerplate. Used
  at 2 sites (`createBroadcastColIndices`'s `seqKPerSplitConst` and
  `applySlidingWindowMask`'s `zeroI32` clamp).
- Tighten `Type si32 = getSi32Type(...)` to
  `IntegerType si32 = ...` across the file (`getSi32Type` already
  returns `IntegerType`); removes a `cast<IntegerType>` at the
  zero-clamp call site.

Drive-by: clang-format catches a stray blank line in
`mlir-c/Dialect/MIGraphX.h` and a line-break in `RocMLIRPasses.td`'s
`dependentDialects` list, both pre-existing in the branch.

No behavior change. All attention lit/E2E tests pass; full
check-rocmlir is green (1370/1370).

Co-authored-by: Cursor <cursoragent@cursor.com>
Five attention findings from a follow-up review:

1. (High) Reject dynamic dims. The verifier and downstream lowering do
   static shape arithmetic (% on seqK, leading-dim collapse, body QK
   shape construction). Allowing dynamic dims through migraphx.attention
   produced silently-broken IR. Add an up-front static-shape check on
   queries/keys/values/result/lse/currentSeqLen/prefixOffset and each
   preSoftmaxElemWiseInput, mirroring MultiBroadcastOp::verify's
   existing convention for this dialect.

2. (High) Require Q.elemType == K.elemType. The op definition let Q and
   K independently pick from `[F32, F16, BF16, I8]`. Mixed Q/K element
   types passed the verifier and fed into a `migraphx.dot` with
   mismatched operand types or routed via Q's element type alone in
   the GEMM-op selection. Reject mixed Q/K up front.

3. (Medium) Widen the host's second GEMM to softmaxType. The CPU
   reference's `migraphx.dot(softmax, V)` accumulated in V's element
   type while the GPU MFMA path keeps gemm1's accumulator at
   softmaxType (typically f32) and downcasts only at the end. For long
   sequences the CPU reference would diverge from the GPU. Convert V
   to softmaxType, run the dot in softmaxType, downcast the result to
   V's element type (see the sketch after this list). The existing TODO
   at the second-GEMM site is now resolved and removed.

4. (Medium) Require lse.elemType == effective softmaxType. Both host
   and GPU compute LSE intermediates (reduce_sum, log, max, add) in
   the softmax type. Allowing an LSE result wider than softmaxType
   would silently round-trip through narrower intermediates. With
   softmaxType absent the effective type is V's element type; in both
   cases the LSE output must match.

5. (Low) Tighten kvcache-scale-lse threshold. After (3) the
   pre-existing `RMS_threshold = 0.03 / relDiff_threshold = 0.2`
   workaround on `mixr-attention-kvcache-scale-lse.mlir` is no longer
   needed; the test now passes at `RMS_threshold = 0.002 /
   relDiff_threshold = 0.005`, matching sibling attention tests.
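
The widened second GEMM from item 3, as a NumPy sketch (hypothetical names, assuming the softmax output and V are cast to the softmax type before the dot):

import numpy as np

def second_gemm_widened(softmax_probs, v, softmax_dtype=np.float32):
    v_dtype = v.dtype
    out = softmax_probs.astype(softmax_dtype) @ v.astype(softmax_dtype)
    return out.astype(v_dtype)   # downcast back to V's element type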

Drive-bys:
- Polish the rank-2-rejection comment that read as if rank-2 still
  slips past (it does not).
- Fix the long-standing `stlide` -> `stride` typo in
  MIGraphXTypes.td.

Tests:
- 4 new invalid tests: dynamic Q, dynamic preSoftmax input,
  Q/K type mismatch, lse vs softmax mismatch.
- Update 2 ops.mlir tests and 1 attention-to-rock test to add the
  now-required `softmax_type = f32` for V.f16 + LSE.f32 cases.
- Replace `decompose_lse_type_convert` with
  `decompose_widened_second_gemm` exercising the widening path
  (V f16 + softmax_type=f32 + LSE f32 emits a convert before and
  after the second dot).

Verification: full check-rocmlir passes (1370/1370, 0 fail, 4 XFAIL).
Co-authored-by: Cursor <cursoragent@cursor.com>