Introduce migraphx.attention op in MIGraphX Dialect #2316
Draft
umangyadav wants to merge 39 commits into develop
Conversation
Contributor
Pull request overview
This PR introduces a first-class migraphx.attention op to represent attention variants (via a composable feature bitmask) in the MIGraphX dialect, reducing downstream pattern-matching complexity and enabling clearer lowering paths for host (decompose) vs GPU (lower to rock.attention).
Changes:
- Add migraphx.attention + migraphx.yield, along with attention feature flags and verifier coverage (valid/invalid tests).
- Add host-side decomposition of migraphx.attention inside MIGraphXTransform and a new GPU lowering pass migraphx-attention-to-rock.
- Integrate the new pass into pipelines and add initial E2E + conversion + C API construction tests.
Reviewed changes
Copilot reviewed 38 out of 38 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| mlir/test/rocmlir-driver/pipelines.mlir | Updates expected high-level pipelines to include migraphx-attention-to-rock. |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-basic.mlir | Adds E2E attention test case (basic). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-softmax-f32.mlir | Adds E2E attention test case (softmaxType f32). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-lse.mlir | Adds E2E attention test case (LSE output). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-scale.mlir | Adds E2E attention test case (pre-softmax scale/bias region via migraphx.yield). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-causal.mlir | Adds E2E attention test case (causal). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-causal-scale.mlir | Adds E2E attention test case (causal + pre-softmax body). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache.mlir | Adds E2E attention test case (kvcache). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal.mlir | Adds E2E attention test case (kvcache + causal). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal-prefix.mlir | Adds E2E attention test case (kvcache + causal + prefix offset). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal-sliding-window.mlir | Adds E2E attention test case (kvcache + causal + sliding window). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-scale.mlir | Adds E2E attention test case (kvcache + pre-softmax body). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-splitkv.mlir | Adds E2E attention test case (splitKV). |
| mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-gqa.mlir | Adds E2E attention test case (GQA shapes). |
| mlir/test/Dialect/MIGraphX/ops.mlir | Adds parsing/printing coverage for new ops/attrs and attention variants. |
| mlir/test/Dialect/MIGraphX/invalid.mlir | Adds verifier negative tests for attention feature/shape/operand constraints. |
| mlir/test/Conversion/MIGraphXToTosa/migraphx-to-tosa-preserves-rock-attention.mlir | Ensures MIGraphXToTosa doesn’t rewrite rock.attention regions produced earlier. |
| mlir/test/Conversion/MIGraphXAttentionToRock/attention-to-rock.mlir | Adds conversion tests for --migraphx-attention-to-rock. |
| mlir/test/Conversion/MIGraphXAttentionDecompose/attention-decompose.mlir | Adds host decomposition tests for attention variants and feature combinations. |
| mlir/test/CAPI/mixr_attention.c | Adds C API test that constructs migraphx.attention ops. |
| mlir/test/CAPI/CMakeLists.txt | Builds/links new mlir-mixr-attention-test. |
| mlir/lib/Dialect/MIGraphX/Transforms/MIGraphXTransform.cpp | Implements host-side migraphx.attention decomposition (non-kernel funcs). |
| mlir/lib/Dialect/MIGraphX/Pipeline/Pipeline.cpp | Inserts MIGraphXAttentionToRock into the high-level pipeline before TOSA/Linalg lowering. |
| mlir/lib/Dialect/MIGraphX/Pipeline/CMakeLists.txt | Links in new MLIRMIGraphXAttentionToRock library. |
| mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp | Adds AttentionOp verifier + feature dependency checks. |
| mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosaPass.cpp | Marks Rock dialect / rock.attention recursively legal to preserve nested ops. |
| mlir/lib/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.cpp | New lowering pass from migraphx.attention to rock.attention for kernel funcs. |
| mlir/lib/Conversion/MIGraphXAttentionToRock/CMakeLists.txt | Adds conversion library target for MIGraphXAttentionToRock. |
| mlir/lib/Conversion/CMakeLists.txt | Adds new conversion subdirectory. |
| mlir/lib/CAPI/Dialect/MIGraphX.cpp | Adds C API builder helper rocmlirMIGraphXAttentionCreate. |
| mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphXTypes.td | Adds AttentionFeatures bit-flag enum attr. |
| mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td | Adds migraphx.yield, migraphx.attention, and applies Elementwise trait to elementwise ops. |
| mlir/include/mlir/Conversion/RocMLIRPasses.td | Declares migraphx-attention-to-rock pass. |
| mlir/include/mlir/Conversion/RocMLIRPasses.h | Exposes new conversion header. |
| mlir/include/mlir/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.h | Adds public pass declaration header. |
| mlir/include/mlir-c/Dialect/MIGraphX.h | Bumps C API version and declares attention builder + feature bit macros. |
Contributor
Some general comments/questions:
Drop fallback handling for the legacy kernel attribute in MIGraphX attention passes and migrate affected conversion tests to rock.kernel so behavior and test coverage match the new attribute contract. Made-with: Cursor
Per the workspace-hygiene rule, plan files and scratch notes should live in a git-excluded directory and never be committed. Add the conventional plans/, scratch/, and notes/ paths to .gitignore so per-session working documents (e.g. review reports, TODO outlines) stay out of the tree. Made-with: Cursor
…ffset

Previously the verifier accepted currentSeqLen / prefixOffset operands of "rank 1 or 2", but MIGraphXAttentionToRock collapses 4-D Q/K/V to Rock's [B*H, seq, dim] layout without expanding rank-1 [batch] operands across heads. The result was either a confusing rock.attention verifier error (currentSeqLen) or a silent OOB index/load at the gridwise lowering (prefixOffset, which has no shape check on the Rock side).

Tighten AttentionOp::verify so the operand shape must equal Q's full leading dims exactly (e.g. [batch] for 3-D Q, [batch, numHeads] for 4-D Q). Producers with a per-batch sequence length must broadcast across heads explicitly via migraphx.multibroadcast before constructing the op; the flattened [B*H] layout remains an internal detail of the kernel-side rock.attention and is materialised by the lowering.

Update an attention-decompose test that was using a rank-1 prefixOffset with 4-D Q (which the new rule rejects), add positive coverage for the canonical 4-D + rank-2 [B, H] form in ops.mlir, and add invalid-IR tests for both currentSeqLen and prefixOffset rank-1 + multi-head Q.

Made-with: Cursor
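For concreteness, a minimal C++ sketch of the tightened rule. The helper and parameter names here are illustrative assumptions (the real check lives inside AttentionOp::verify and operates on the dialect's own shaped types):

```cpp
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

// Sketch: currentSeqLen / prefixOffset must match Q's leading dims
// exactly ([batch] for 3-D Q, [batch, numHeads] for 4-D Q).
static mlir::LogicalResult
checkLeadingDimsMatchQ(mlir::Operation *op, mlir::ShapedType qType,
                       mlir::ShapedType operandType, llvm::StringRef name) {
  // Q's shape is [leading dims..., seqQ, headDim]; drop the last two.
  llvm::ArrayRef<int64_t> expected = qType.getShape().drop_back(2);
  if (operandType.getShape() != expected)
    return op->emitOpError(name)
           << " shape must equal Q's leading dims exactly; broadcast "
              "across heads with migraphx.multibroadcast if needed";
  return mlir::success();
}
```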
The verifier already enforced "slidingWindowSize attribute set ⇒ sliding_window feature set" but not the inverse. With features = "kvcache|sliding_window" and no attribute, the host decompose path hits an unreachable assertion (release builds null-deref) and the GPU lowering silently degrades to plain KV-cache masking with no diagnostic, because MIGraphXAttentionToRock only forwards a non-null attribute to rock.attention. Either way the user explicitly asked for sliding-window attention and got the wrong answer or a crash.

Add a verifyAttrRequiredByFeature helper symmetric with the existing verifyOperandRequiredByFeature and use it to reject sliding_window-without-slidingWindowSize at the migraphx.attention op level. Cover the new path with an invalid-IR test and tidy up a leftover comment block from a deleted helper while we're here.

Made-with: Cursor
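A sketch of what such a helper could look like; the parameter spelling is an assumption, only the helper's name and intent come from the commit:

```cpp
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Operation.h"

// Sketch: symmetric counterpart to verifyOperandRequiredByFeature —
// a set feature bit demands its companion attribute be present.
static mlir::LogicalResult
verifyAttrRequiredByFeature(mlir::Operation *op, bool featureIsSet,
                            mlir::Attribute attr, llvm::StringRef feature,
                            llvm::StringRef attrName) {
  if (featureIsSet && !attr)
    return op->emitOpError("feature '")
           << feature << "' requires the '" << attrName << "' attribute";
  return mlir::success();
}
```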
The existing i8 attention code path goes through migraphx.quant_dot + migraphx.dequantizelinear + migraphx.softmax + migraphx.dot at the graph level, fused into a rock.attention by TosaToRock's pattern matcher. Direct construction of migraphx.attention with i8 Q/K was type-checked as legal (AttentionQKTypes includes I8/SI8) but had no working lowering: the MIGraphXAttentionToRock body builder forced every linalg.generic block arg to the output element type and only emitted arith.*f, and the host AttentionDecompose path emitted migraphx.dot, which doesn't accept i8.

Make the direct path actually work end-to-end:
- AttentionOp::verify accepts migraphx.dequantizelinear in preSoftmaxBody alongside Elementwise-trait ops via a small allowlist, and requires softmaxType to be set explicitly when Q is non-float so the body has a known float target type to dequantize to.
- MIGraphXAttentionToRock now derives each linalg.generic block-arg type from its own input (allowing i32 QK from an i8 first GEMM to flow into the body until a dequantize op upcasts it) and lowers migraphx.dequantizelinear to (cast<float>(input) - cast<float>(bias)) * scale using the upstream mlir::convertScalarToDtype helper that MIGraphXToLinalg::castTensor already uses; signedness is read from the original MIGraphX-side operand type because MIXRShapedType::asTensor drops it.
- AttentionDecompose (host CPU path used by --verifier clone) emits migraphx.quant_dot with an i32 QK result for integer Q/K, mirroring what the existing TosaToRock-driven path produces.

Also fix a pre-existing bug where the softmaxType-vs-elementType convert check compared against Q's original type instead of the body-yielded current type, which meant a redundant convert was being added when the body had already changed precision.

Tests:
- Update the previously-empty migraphx_attention_i8_qk in ops.mlir to use a meaningful body with migraphx.dequantizelinear and softmax_type = f32.
- Add invalid.mlir coverage for the new "softmaxType required when Q is integer" rule.
- Add two attention-to-rock.mlir lit tests covering the new arith.sitofp / arith.subf / arith.mulf scalar lowering with and without a dequant zero-point.
- Add five pr-e2e/migraphx-attention/mixr-attention-first-gemm-i8*.mlir E2E tests that mirror the existing pr-e2e/attention/* originals (kept alongside) but expressed as a single migraphx.attention with the dequant in preSoftmaxBody. They round-trip cleanly through --verifier clone, confirming the host quant_dot path produces matching numerics.

Made-with: Cursor
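A minimal sketch of the scalar dequantize lowering described above. mlir::convertScalarToDtype is the upstream Arith-utils helper the commit names; the wrapper function and its signature are assumptions for illustration:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"

// Sketch: dequantizelinear(x, scale, zp) -> (float(x) - float(zp)) * scale.
// The signed/unsigned choice must come from the original MIGraphX-side
// operand type, since the tensor form drops signedness.
static mlir::Value lowerDequantScalar(mlir::OpBuilder &b, mlir::Location loc,
                                      mlir::Value input, mlir::Value scale,
                                      mlir::Value zeroPoint, bool isUnsigned) {
  mlir::Type floatTy = scale.getType();
  mlir::Value x =
      mlir::convertScalarToDtype(b, loc, input, floatTy, isUnsigned);
  mlir::Value zp =
      mlir::convertScalarToDtype(b, loc, zeroPoint, floatTy, isUnsigned);
  mlir::Value shifted = b.create<mlir::arith::SubFOp>(loc, x, zp);
  return b.create<mlir::arith::MulFOp>(loc, shifted, scale);
}
```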
The verifier accepted any Q whose leading dims were divisible by K's
("queries=qd is divisible by keys=kd"), but for rank-3 Q the (batch,
numHeads) split is unrecoverable from the shape alone: <12x4x8> Q
against <4x8x16> K could mean B=1/H_q=12/H_kv=4 or B=2/H_q=6/H_kv=2 or
B=4/H_q=3/H_kv=1, each producing a different (numHeadsQ, numHeadsKV,
batch) for rock.attention. Both downstream paths fall over on this
input today: the host AttentionDecompose emits an invalid migraphx.dot
(batch dim mismatch) and the GPU MIGraphXAttentionToRock fails Rock's
verifier (numHeadsQ defaults to 1 because getNumHeads only knew the
4D shape).
Tighten AttentionOp::verify so GQA (any leading dim of Q != K's) is
only legal when Q is rank >= 4, where the heads axis is unambiguously
dim 1 by convention. Equal-heads 3D Q (numHeadsQ == numHeadsKV) is
still accepted because rock.attention is symmetric in that case
between "8 heads x 1 batch" and "1 head x 8 batches".
The previous reshape-trace fallback in MIGraphXAttentionToRock::
getNumHeads (which looked at a defining migraphx.reshape to recover
the 4D heads count from a 3D Q) is now dead code: with the verifier
rule above, a 3D Q only reaches this lowering when numHeadsQ ==
numHeadsKV, where the choice between "real heads" and "1 head, big
batch" is numerically irrelevant. Drop the trace and the
attention_gqa_3d unit test that exercised it; add a negative
invalid.mlir test confirming the new rule's diagnostic.
Add two new pr-e2e/migraphx-attention/ E2E tests
(mixr-attention-gqa-bias, mixr-attention-gqa-scale) mirroring the
existing pr-e2e/attention/* GQA tests but expressed as a single
migraphx.attention with a preSoftmax body. Originals are kept; this
broadens coverage of the direct migraphx.attention path through the
GQA lowering chain.
Made-with: Cursor
The verifier was checking the preSoftmaxBody block argument *count* but not their *types*: a body declaring block-arg 0 as <2x4x99xf16> against a real QK shape of <2x4x16xf16> would slip through and fail far downstream during lowering. The splitKV path also only checked the rank of preSoftmaxElemWiseInputs, not their dim sizes. With the M5 i8 support landed earlier in this branch, block-arg 0's element type now also depends on whether Q is integer (the body sees the i32 quant_dot output and is expected to dequantize it), which the old verifier did not enforce either.

Consolidate the body checks into a single block that:
- Computes the expected QK shape once (with splitKV inflation when the splitkv feature is set), and the expected QK element type (Q's type for float Q, i32 for integer Q).
- Walks the body once with block.without_terminator(), confirming each op is in the existing migraphx.elementwise / dequantizelinear allowlist.
- For an empty body: requires a bare yield and zero block arguments.
- For a populated body, in order: arg count matches, block-arg 0 has exactly the QK shape and element type, each block-arg i+1 matches preSoftmaxElemWiseInputs[i] and that input's shape matches the QK shape, and the yield's value shape matches the QK shape (element type is free so the body can dequantize / convert before softmax).

This is a verifier-only change. All previously-valid IR still verifies (confirmed by the targeted attention sweep: 772 tests, 743 pass, 0 unexpected failures). Update one existing invalid.mlir test whose old splitKV-rank-only diagnostic is now subsumed by the more specific block-arg-0 check, and add four new negative tests covering wrong block-arg shape, wrong block-arg element type for i8 Q (signedness), block-arg i+1 vs input type mismatch, and input shape vs QK shape mismatch.

Made-with: Cursor
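For concreteness, a condensed C++ sketch of the populated-body branch. All names are illustrative assumptions (isAllowedBodyOp stands in for the verifier's local allowlist, and the dialect's shaped type is simplified to ShapedType):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"

static bool isAllowedBodyOp(mlir::Operation *op); // elementwise/dequant allowlist

// Sketch: expectedQKShape/ElemType are computed once up front, with
// splitKV inflation already applied when the splitkv feature is set.
static mlir::LogicalResult
checkPopulatedBody(mlir::Operation *op, mlir::Block &body,
                   size_t numElemWiseInputs,
                   llvm::ArrayRef<int64_t> expectedQKShape,
                   mlir::Type expectedQKElemType) {
  for (mlir::Operation &inner : body.without_terminator())
    if (!isAllowedBodyOp(&inner))
      return inner.emitOpError("op not allowed in preSoftmaxBody");
  if (body.getNumArguments() != numElemWiseInputs + 1)
    return op->emitOpError(
        "body takes the QK value plus one arg per preSoftmaxElemWiseInput");
  auto arg0Ty = mlir::cast<mlir::ShapedType>(body.getArgument(0).getType());
  if (arg0Ty.getShape() != expectedQKShape ||
      arg0Ty.getElementType() != expectedQKElemType)
    return op->emitOpError("block-arg 0 must match the QK shape and "
                           "element type exactly");
  // Block-args i+1 vs inputs and the yield's shape are checked the same
  // way; the yielded element type is deliberately left free.
  return mlir::success();
}
```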
The host AttentionDecompose path defaulted softmaxElemType to the value entering softmax (Q's element type for an empty body, the body's yielded type otherwise) while rock::gridwise_attention_accel defaulted to V's element type. When Q.elemType != V.elemType and softmaxType was unset, the two paths would silently softmax in different precisions and --verifier clone would diverge. The previous M5-era check guarded the narrower "Q is integer-typed" subset of this footgun but missed the all-float-but-mismatched-precisions case (e.g. bf16 Q/K with f32 V).

Replace the integer-Q-specific check with a more general one in AttentionOp::verify: when the value entering softmax doesn't already have V's element type, the producer must set softmaxType explicitly so the lowering can insert convert ops on either side of softmax. The new rule is strictly stronger -- it still rejects all the cases the old rule rejected (integer Q with empty body still fails because i32 != V), correctly accepts the case the old rule wrongly rejected (integer Q with a body that dequantizes to V's element type), and additionally catches the float-but-mismatched case the old rule missed. The check moved to after the body validation so it can read the yielded element type.

Align the host AttentionDecompose default with rock's gridwise default by reading op.getSoftmaxType().value_or(vType.getElementType()), mirroring rock::gridwise_attention_accel:2147 line for line. With the new verifier rule above this is provably safe -- either softmaxType is explicitly set (and value_or returns it) or the value entering softmax already matches V (so the default lands on the same value the old "qkCurrentElemType" default would have picked).

Add a TODO at the host's second GEMM noting that migraphx.dot accumulates in the operand element type rather than promoting to the softmax type the way the GPU mfma path does. For long sequences this can produce slightly less accurate CPU reference values than the GPU; widening the dot's accumulator (or splitting into f32 partial sums + downcast) would close the gap.

Update the M5-era invalid-IR test whose old diagnostic message changed and add two new negative tests:
- Float Q (bf16) and V (f32) with no body and no softmaxType -- the C2.4 footgun, now rejected at the verifier.
- Body that yields a different element type than V (f32 vs f16) with no softmaxType -- same rule, populated-body branch.

Add a positive E2E test mirroring the new shape: bf16 Q/K with f32 V and an explicit softmax_type = f32, batch=1 heads=4. Both the host CPU reference and the GPU kernel agree within bf16 precision (verified via --verifier clone).

Made-with: Cursor
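A minimal sketch of the new rule and the aligned default. The value_or expression is quoted from the message above; everything else (free-standing functions, the optional-typed accessor) is an illustrative assumption:

```cpp
#include <optional>
#include "mlir/IR/Operation.h"

// Verifier side (sketch): softmaxType must be explicit whenever the
// value entering softmax doesn't already carry V's element type.
static mlir::LogicalResult
checkSoftmaxTypeRule(mlir::Operation *op, mlir::Type typeEnteringSoftmax,
                     mlir::Type vElemType,
                     std::optional<mlir::Type> softmaxType) {
  if (typeEnteringSoftmax != vElemType && !softmaxType)
    return op->emitOpError("softmaxType must be set explicitly when the "
                           "value entering softmax does not match V's "
                           "element type");
  return mlir::success();
}

// Host-decompose side (sketch): the same default rock's gridwise
// lowering uses, so both paths pick identical precision.
static mlir::Type effectiveSoftmaxType(std::optional<mlir::Type> softmaxType,
                                       mlir::Type vElemType) {
  return softmaxType.value_or(vElemType);
}
```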
The pr-e2e/attention/ directory has a family of bare matmul+softmax+matmul tests with non-power-of-2 padded shapes (head_dim=3, seq_len=7) and various pre-softmax bodies (none, bias, scale, scale+bias, scale+exp+bias, multi-op tree, full-shape constant scale). These tests cover the TosaToRock attention pattern matcher; they did not exercise the migraphx.attention direct path.

Add nine pr-e2e/migraphx-attention/ counterparts that express the same shapes and numerics as a single migraphx.attention with a preSoftmax body. Originals are kept; the new tests broaden coverage of the direct lowering path through MIGraphXAttentionToRock and the host AttentionDecompose:
- mixr-attention-square-3d (mirror of mixr-attention.mlir, square 3D 64x64 shape, no body)
- mixr-attention-padded (3D 7x3 shape, no body)
- mixr-attention-padded-bias (1-input add body)
- mixr-attention-padded-scale (1-input mul body)
- mixr-attention-padded-scale-and-bias (2-input mul+add body)
- mixr-attention-padded-const-scale (full-shape migraphx.literal scale, shape rewritten from the original 1xf32+multibroadcast to a full-shape literal so the materialised scale matches the QK shape)
- mixr-attention-padded-scale-cross (cross-attention seqQ=128, seqK=27, perf_config forwarded onto migraphx.attention)
- mixr-attention-padded-scale-bias-exp (3-op body: mul, exp, add)
- mixr-attention-padded-complex-tree-elemwise (4-op tree body combining auxiliary inputs at leaf level then scaling/biasing the QK output)

The padded-scale-bias-exp test exercises migraphx.exp in the body. That op already had the Elementwise trait (so the verifier accepted it), but the MIGraphXAttentionToRock scalar lowering only handled mul/add/sub/neg/dequantizelinear. Add a single math.exp scalar lowering plus the matching MLIRMathDialect link library and dependentDialects entry; all other ops in the new tests already had scalar lowerings.

All nine new tests pass end-to-end on GPU and verify against the host CPU reference (--verifier clone). Targeted attention sweep is unchanged otherwise: 782 tests, 753 pass, 4 expected fails (pre-existing GPU-required), 25 unsupported, 0 unexpected failures.

Made-with: Cursor
Two follow-ups from the consolidated review:
R.4: the host applySlidingWindowMask computes
lowerBound = currentSeqLen + (-windowSize)
without clamping to zero. It produced correct numerics today only
because tosa.greater (and migraphx.greater) treat their signed integer
operands as signed, so a negative lower bound compared against a
non-negative column index returns false and no spurious mask is
applied. The GPU side (rock::gridwise_attention_accel:1844) explicitly
emits arith.maxsi against zero. Mirror that here with a where(0 >
lowerBound, 0, lowerBound) clamp (sketched after this message) so the
host IR no longer depends on the signed-comparison convention and the
two paths use the same documented invariant.
C/N: MIGraphX.td documented the kvcache mask predicate as
"key_index >= currentSeqLen"
but both host (applyKVCacheMask) and GPU
(GridwiseGemmToBlockwise.cpp:1318 ugt mIndex, currentSeqLen) actually
use strict greater. Update the doc text to match the implementation
and clarify that currentSeqLen is the index of the last valid key
position (range [0, currentSeqLen] inclusive), matching PyTorch SDPA
and FlashAttention. Drop a stale "TODO: verify that negative prefixes
work" comment in the same block whose answer the surrounding
prefix_offset description already gives ("the offset can be positive
or negative").
Strip the trailing blank line at the end of invalid.mlir flagged by
git diff --check.
Made-with: Cursor
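A sketch of the host-side clamp in builder code. The migraphx op class names and result types here are assumptions following the dialect's usual Op-suffix convention (the GPU side already emits arith.maxsi directly):

```cpp
#include "mlir/IR/Builders.h"

// Sketch: lowerBound = max(0, currentSeqLen - windowSize), written
// with migraphx ops as where(0 > lowerBound, 0, lowerBound) so the
// host IR matches the documented invariant instead of relying on
// signed comparison of a negative bound.
static mlir::Value clampLowerBound(mlir::OpBuilder &b, mlir::Location loc,
                                   mlir::Type condTy, mlir::Value zero,
                                   mlir::Value lowerBound) {
  mlir::Value isNegative =
      b.create<migraphx::GreaterOp>(loc, condTy, zero, lowerBound);
  return b.create<migraphx::WhereOp>(loc, lowerBound.getType(), isNegative,
                                     zero, lowerBound);
}
```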
The MIGraphXAttentionToRock body lowering only mapped 5 migraphx ops
to scalar arith/math (mul/add/sub/neg/exp/dequantizelinear), but
AttentionOp::verify accepted any op carrying the Elementwise trait
plus migraphx.dequantizelinear. Anything outside the lowering's small
set (e.g. migraphx.div in a body) passed the verifier and then failed
deep in the conversion pass with an opaque "unsupported migraphx op"
diagnostic far from the user's migraphx.attention. Also, the explicit
five-op subset constrained what producers could put in the body
without good reason.
Extend lowerMIGraphXElementwiseToScalar to cover the full set of
scalar-lowerable migraphx elementwise ops (a few entries are sketched
after this message), mirroring MIGraphXToLinalg's ElementwiseConverter /
GenericElementwiseOpConverter coverage so host CPU and GPU body
lowerings stay numerically aligned.
The new mapping is:
Binary float : add (addf), sub (subf), mul (mulf), div (divf),
pow (math.powf)
Unary float : neg (negf), abs (math.absf), ceil (math.ceil),
floor (math.floor), exp (math.exp), log (math.log),
sqrt (math.sqrt), tanh (math.tanh), erf (math.erf)
Composed : recip (1/x), relu (max(0,x)), sigmoid (1/(1+exp(-x)))
Selection : where (i8 cond -> arith.select via i1 cast)
Cast : convert (mlir::convertScalarToDtype, signedness-aware)
Quant : dequantizelinear (unchanged)
Tighten AttentionOp::verify to use an explicit allowlist matching the
new lowering set exactly. The previous Elementwise-trait test let
greater/equal/clip/rsqrt slip through unsupported; the explicit
allowlist makes the verifier-vs-lowering contract crisp and adding a
new body op is a one-line change in two coupled places. Update the
diagnostic to point at lowerMIGraphXElementwiseToScalar so the next
maintainer knows where to look.
Tests:
- Update two pre-existing invalid.mlir tests whose assertions
referenced the old "must only contain elementwise migraphx ops"
wording; the cases (migraphx.dot and migraphx.reduce_sum in body)
still reject, just under the new message.
- Add attention_extended_body_ops to attention-to-rock.mlir exercising
div/pow/neg/abs/exp/log/sqrt/tanh/erf/recip/relu/sigmoid/where/
convert in a single body, with FileCheck patterns confirming each
maps to its expected arith/math scalar op.
Verification: targeted attention sweep is unchanged at 782 tests,
753 pass, 4 expected fails (pre-existing GPU-required), 25
unsupported, 0 unexpected failures.
Made-with: Cursor
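A sketch of a few entries from the extended dispatch, including two of the composed forms. The migraphx op class names are assumptions following the dialect's Op-suffix convention; the dispatch shape (a TypeSwitch over the body op) is illustrative:

```cpp
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"

// Sketch: 'args' are the already-lowered scalar operands of 'op'.
static mlir::Value lowerScalar(mlir::OpBuilder &b, mlir::Location loc,
                               mlir::Operation *op, mlir::ValueRange args) {
  auto one = [&]() -> mlir::Value {
    return b.create<mlir::arith::ConstantOp>(
        loc, b.getFloatAttr(args[0].getType(), 1.0));
  };
  return llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op)
      .Case<migraphx::DivOp>([&](auto) -> mlir::Value {
        return b.create<mlir::arith::DivFOp>(loc, args[0], args[1]);
      })
      .Case<migraphx::RecipOp>([&](auto) -> mlir::Value { // recip(x) = 1 / x
        return b.create<mlir::arith::DivFOp>(loc, one(), args[0]);
      })
      .Case<migraphx::SigmoidOp>([&](auto) -> mlir::Value {
        // sigmoid(x) = 1 / (1 + exp(-x))
        mlir::Value negX = b.create<mlir::arith::NegFOp>(loc, args[0]);
        mlir::Value expV = b.create<mlir::math::ExpOp>(loc, negX);
        mlir::Value denom = b.create<mlir::arith::AddFOp>(loc, one(), expV);
        return b.create<mlir::arith::DivFOp>(loc, one(), denom);
      })
      .Default([](auto) { return mlir::Value(); }); // real code errors here
}
```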
Two contracts were duplicated across the migraphx.attention lowering chain: the QK output element type rule (Q.elemType for float Q, i32 for integer Q from a quantized first GEMM), and the closed set of body ops that MIGraphXAttentionToRock::lowerMIGraphXElementwiseToScalar can scalar-lower. The element-type rule was open-coded in three places (AttentionOp::verify, MIGraphXTransform's host AttentionDecompose, and rocmlir-gen's attention generator) and the body allowlist was both a local lambda in the verifier and an implicit dispatch table in the lowering, with only a comment cross-referencing them.

Add a header-only mlir/include/mlir/Dialect/MIGraphX/IR/AttentionUtils.h exposing two small inline helpers:
- computeAttentionQKElemType(qElemType, ctx): encodes the QK-output element type rule.
- isAllowedInPreSoftmaxBody(op): the closed allowlist of body ops with a matching scalar lowering.

Refactor the three element-type call sites (verifier / host decompose / rocmlir-gen) and the verifier allowlist to consume the shared helpers, removing the local lambdas and ternaries. The header's docstring calls out the lock-step coupling between isAllowedInPreSoftmaxBody and the lowering's dispatch table so future maintainers know to update both places when adding a new body op.

Items NOT extracted (deliberately):
- expectedQKShape: the verifier and host decompose work from different inputs (pre- vs post-splitKV-reshape types) so unifying would obscure the splitKV reshape that's already happening.
- getNumHeads vs TosaToRock's getNumHeadsGQA: same name, different operations (read dim 1 of a 4D migraphx type vs walk back through tensor.collapse_shape on the tosa side).
- Contiguous-strides construction: single use, no duplication.
- isMixrUnsignedInt: single use; the inline form is clearer at the call site.

Verification: targeted attention sweep is unchanged at 782 tests, 753 pass, 4 expected fails (pre-existing GPU-required), 25 unsupported, 0 unexpected failures. Pure NFC -- no behavior change, just shifts the rule definitions to one place.

Made-with: Cursor
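A sketch of the two helpers, with signatures inferred from the commit text; the real allowlist enumerates the lowering's full op set, and the migraphx op class names are assumptions:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

// Sketch of the QK-output element type rule: float Q keeps Q's element
// type; integer Q (quantized first GEMM) produces an i32 QK output.
inline mlir::Type computeAttentionQKElemType(mlir::Type qElemType,
                                             mlir::MLIRContext *ctx) {
  return mlir::isa<mlir::FloatType>(qElemType)
             ? qElemType
             : mlir::IntegerType::get(ctx, /*width=*/32);
}

// Sketch of the closed preSoftmaxBody allowlist; must stay in lock
// step with lowerMIGraphXElementwiseToScalar's dispatch table.
inline bool isAllowedInPreSoftmaxBody(mlir::Operation *op) {
  return mlir::isa<migraphx::AddOp, migraphx::SubOp, migraphx::MulOp,
                   migraphx::DivOp, migraphx::DeQuantizeLinearOp
                   /* ...remaining allowlisted ops... */>(op);
}
```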
Three refactors that don't change behaviour:
- Validate splitKV up-front in AttentionOp::verify so result/LSE/QK shape construction can use a single validated effectiveSplitKV value, instead of computing it twice with subtly different gates (>1 vs. just present).
- Drop the unreachable `splitKV <= 0` check: the orphan-attr path now rejects splitKV-without-feature, and the splitkv-feature path requires splitKV > 1, so a zero/negative value is always caught earlier with a more specific diagnostic.
- Replace three open-coded contiguous-stride loops in MIGraphXTransform's AttentionDecompose with the existing makeContiguousType helper.

Co-authored-by: Cursor <cursoragent@cursor.com>
…ndow clamp

Two small follow-ups for the migraphx.attention preSoftmaxBody and sliding-window mask lowering:
- Assert in AttentionToRockPattern's body builder that any op the verifier admits via isAllowedInPreSoftmaxBody is also lowered by lowerMIGraphXElementwiseToScalar. The two lists were already documented as a lock-step pair, but nothing was catching divergence beyond a structured runtime error. The assertion fires in debug builds the moment the dispatcher is missing a case the verifier approves; release builds continue to surface the same "unsupported migraphx op in preSoftmaxBody" error.
- Add an E2E test pinning currentSeqLen to [0, 2] with slidingWindowSize = 4 so the unclamped lowerBound formula (currentSeqLen - windowSize) is always negative. Both host and GPU apply max(0, .) and produce the same numerics; --verifier=clone guards against a future regression where one side stops clamping.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two related verifier additions that close gaps the previous review flagged:
- Forbid features = "splitkv | sliding_window" outright. splitkv reshapes the body to operate on per-chunk [seqK / splitKV] columns while sliding_window's lower bound lives in absolute K-position space, and the gridwise lowering's mask uses that absolute index too. Nothing in the current pipeline reconciles those two views and no test exercises the combination, so reject it explicitly rather than have host and GPU silently disagree. Adds a new verifyFeatureMutualExclusion helper for the rejection.
- Spell out that 'sliding_window' requires currentSeqLen, on top of its existing dependency on 'kvcache'. Today the kvcache requirement catches the missing operand transitively; making the rule explicit protects the sliding-window path if the kvcache linkage ever changes.

Add an invalid test for the splitkv + sliding_window rejection.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two small cleanups in the migraphx.attention lowering and host decomposition:
- Drop convertMIXRToTensor in MIGraphXAttentionToRock. It was a one-line wrapper around migraphx::AsLogicalShapeOp::create used in exactly one call site; the rest of the file already calls the op builder directly, so inline the only caller.
- Tighten the host LSE reshape in AttentionDecompose. Pull the target shape directly from the op's lse result type and name the current element type, so the data flow is clearer than the earlier "build a vector from the verifier-output shape and walk it twice" pattern.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two correctness bugs that surface when splitKV interacts with the preSoftmax body or attention masks:
- AttentionToRockPattern unconditionally set preSoftmaxHasSplitKVTransforms = false. The verifier requires preSoftmax inputs to live in split space ([B*H*splitKV, SeqQ, SeqK/splitKV]), but without this attribute the gridwise lowering leaves the GEMM0 output in the un-split shape and the linalg.generic body iterates the two operands in disagreeing index spaces. Set the attr to true when splitKV > 1 and the body has inputs.
- Host AttentionDecompose's mask helpers (causal / kvcache / sliding window) used qkShape[rank-1] as their column iota, which is the per-chunk seqK / splitKV when the body operates in split space. Masks compare against the original currentSeqLen / row index, so per-chunk column indices mask the wrong keys for non-zero split chunks. Plumb splitKV through createBroadcastColIndices and emit splitIdx * seqKPerSplit + localCol when splitKV > 1, matching the GPU's absolute mIndex logic.

Add an IR-level FileCheck test exercising the global-index path for splitkv + kvcache, and document the still-loose splitkv-scale e2e threshold (the host's per-chunk softmax decomposition diverges from GPU's combined-via-LSE recombination irrespective of these fixes).

Co-authored-by: Cursor <cursoragent@cursor.com>
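A sketch of the global-column computation in builder code; the helper and op class names are assumptions for illustration:

```cpp
#include "mlir/IR/Builders.h"

// Sketch: colAbs = splitIdx * seqKPerSplit + localCol — lift per-chunk
// column indices into absolute K-position space, matching the GPU's
// absolute mIndex logic.
static mlir::Value makeGlobalColIndices(mlir::OpBuilder &b, mlir::Location loc,
                                        mlir::Type idxTy, mlir::Value splitIdx,
                                        mlir::Value localCol,
                                        mlir::Value seqKPerSplitConst) {
  mlir::Value scaled =
      b.create<migraphx::MulOp>(loc, idxTy, splitIdx, seqKPerSplitConst);
  return b.create<migraphx::AddOp>(loc, idxTy, scaled, localCol);
}
```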
Five verifier rules that previously slipped past the migraphx-side
verifier and surfaced as opaque downstream errors (or, worse, invalid
IR like arith.addf-on-i32):
- Reject integer arithmetic in the preSoftmax body (G3). Body block
arg 0 is i32 for integer-Q attention (output of quant_dot), but the
scalar lowering only emits arith.{add,mul,sub,div,...}f. Reject any
body op other than dequantizelinear / convert / where whose operands
are integer; the body must dequantize or convert before arithmetic
(see the sketch after this message).
- Restrict softmaxType to {f16, bf16, f32} (G4). f64 and other exotic
floats verified at the migraphx layer but tripped a less helpful
rock-internal verifier later.
- Constrain attention result to AttentionVTypes and lse to
{f16, bf16, f32} (G5). The TableGen had AnyMIXRShaped on both, so
e.g. an i32 attention result verified happily.
- Require Q rank exactly 4 when GQA is active (G6). Rank 3 was
already rejected; rank 5+ now is too, because neither the host
broadcastForGQA nor the GPU getNumHeads detector handles higher ranks
- they assumed rank 4 and produced wrong index spaces or fell back
to numHeads = 1.
- Require Q / K / V rank >= 3 (G7). Rank 2 attention slipped past the
migraphx verifier and failed to legalise through rock-gridwise with
an opaque "failed to legalize rock.attention" error. Rock requires
a leading batch dim, so reject rank 2 here with a clear diagnostic.
The previous commit already made splitkv + sliding_window work
correctly (the global-col-index fix covers both kvcache and
sliding_window masks under splitKV), so drop the now-unnecessary
splitkv vs sliding_window mutual exclusion and its helper.
Add invalid tests for each new rule.
Co-authored-by: Cursor <cursoragent@cursor.com>
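A sketch of the G3 rule; the exempted op class names are assumptions, and getElementTypeOrSelf stands in for the dialect's own element-type accessor:

```cpp
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"

// Sketch: any body op other than dequantizelinear / convert / where
// must see float operands; everything else has to dequantize or
// convert before arithmetic.
static mlir::LogicalResult checkNoIntegerBodyArith(mlir::Block &body) {
  for (mlir::Operation &inner : body.without_terminator()) {
    if (mlir::isa<migraphx::DeQuantizeLinearOp, migraphx::ConvertOp,
                  migraphx::WhereOp>(&inner))
      continue; // these three may legitimately touch integers
    for (mlir::Value operand : inner.getOperands())
      if (!mlir::isa<mlir::FloatType>(
              mlir::getElementTypeOrSelf(operand.getType())))
        return inner.emitOpError(
            "integer operand in preSoftmaxBody arithmetic; dequantize or "
            "convert before arithmetic");
  }
  return mlir::success();
}
```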
The earlier comment blamed AttentionDecompose's per-chunk softmax for the host-vs-GPU divergence on this test. After comparing a tiny hand-decomposed splitkv attention (the same pattern develop's pr-e2e/attention/mixr-attention-flash-decoding.mlir uses) against this branch's migraphx.attention path, both produce the same per-element divergence and the GPU output's second split chunk is partially uninitialised (NaN) on the smaller shape. The bug is in rock::gridwise_attention_accel rather than in this branch.

Update the in-test comment to point at the right place so a future fix targets the right pass; thresholds stay loose until the rock kernel writes both partials cleanly.

Co-authored-by: Cursor <cursoragent@cursor.com>
User pointed out that splitkv testing needs gridSize * splitKV workgroups. Confirmed via the rock-side formula computeGridSizeAttentionGemmElmtGemm (gridSize = (gemm0N / NPerBlock) * gemm0G * splitKV): the launch is correct, the per-workgroup coords (g, n_block, split_block) are unique via makeGxNGridLayout, and the rock pipeline does launch one WG per chunk. The bug is elsewhere in gridwise_attention_accel: chunk 1 keeps stale memory even though its workgroup runs.

Update the comment to reflect this so a future rock-side fix targets the right code path.

Co-authored-by: Cursor <cursoragent@cursor.com>
User asked whether splitkv needs an output prefill. Confirmed via
deep dive: prefill alone wouldn't fix this. The actual bug is in
rock::gridwise_attention_accel.postProcessFirstGemm
(GridwiseGemmToBlockwise.cpp:1636) - the ThreadwiseReadIntoOp that
reads preSoftmaxElemWiseInputs gets extra indices
{g_block, m_block, n_block, tid}, missing split_block. The body
input is sized [B*H*splitKV, ...] but only the (g, n) coords of the
gemm0 grid are threaded into the read address, so both chunks of
every (g, n) pair read the same body-input slot. Chunk 1 then
computes with the wrong inputs and the output write
(also missing split_block in gridCoordsGemm1) lands on chunk 1's
slot of the output - which holds whatever uninitialised memory
came from the harness alloc.
Confirmed:
- grid is sized correctly (gridSize * splitKV) and bid→(g, n, split)
mapping is unique.
- splitkv without body passes 0.0005 threshold.
- splitkv with a body that has zero extra inputs (just a convert)
passes 0.0005 threshold even at our shape.
- splitkv + body + extra inputs reproduces the divergence identically
on a hand-written develop-style decomposition.
Update the in-test comment to reflect this finding so the rock-side
fix targets the right code path.
Co-authored-by: Cursor <cursoragent@cursor.com>
`createSplitKVTransformsForGemm0Out` returns a stack pointing from body-input shape `[B*H*splitKV, SeqQ, SeqK_chunk]` (top) down to gemm0-buffer shape `[B*H, SeqQ, SeqK]` (bottom). The previous prepend-under-`linalgGridSubTileMaps` left the composed chain ending in gemm0-buffer space, so when `postProcessFirstGemm` then composed `linalgToOtherInputMaps` for each `preSoftmaxElemWiseInput` the bottom no longer matched the user input memref. The resulting `rock.threadwise_read_into` produced wrong addresses for the extra body inputs whenever splitKV > 1.

The bug was silent for bodies with no extra inputs (e.g. develop's flash-decoding test), because the composed chain is only consumed in the `for (genOpInput)` loop in `postProcessFirstGemm`. With an extra input (e.g. our splitkv-scale test) chunk 1 produced large numerical divergence and unreliable LSE.

Invert `splitKVTransforms` before prepending so the composed chain ends in body-input space, which is what `postProcessFirstGemm` expects for each `preSoftmaxElemWiseInput`. Also restore tight `RMS_threshold = 0.002 / relDiff_threshold = 0.0005` on `mixr-attention-splitkv-scale.mlir` (was loosened earlier as a workaround) and drop the long workaround comment.

Co-authored-by: Cursor <cursoragent@cursor.com>
Locks in the transform-stack direction the previous commit fixed for
`postProcessFirstGemm`. The test feeds `rock-gridwise-gemm-to-blockwise`
a `rock.gridwise_attention_accel` op with `splitKV > 1` AND a body that
consumes an extra `preSoftmaxElemWiseInput`, then checks that the
resulting `rock.threadwise_read_into` for the user input ends in two
leaf transforms going gemm0-shape -> body-shape:
Merge{splitKV, seqK_chunk} ["seqK"] -> ["splitKV", "seqK_chunk"]
Unmerge{batch, splitKV} ["batch", "splitKV"] -> ["batch"]
These are the inverse of `createSplitKVTransformsForGemm0Out`'s output
and only appear after the inversion in the previous commit. Reverting
the inversion flips both factor orderings and the test fails.
This was the gap that allowed the underlying bug to land: the existing
splitKV tests in `gridwise_attention_accel_lowering.mlir` use empty
`preSoftmaxOps`, and the existing body tests use `splitKV = 1`, so no
test exercised the buggy code path before.
Co-authored-by: Cursor <cursoragent@cursor.com>
Three small simplifications surfaced in the post-fix branch review:
1. `applySlidingWindowMask`: drop the
`static_cast<uint64_t>(-windowSize)` + computed `signedSemantics`
dance. `currentSeqLen`'s element type is restricted to `[I32, SI32]`
by the op definition, so a plain
`APInt(width, -windowSize, /*isSigned=*/true)` is always correct
and reads more directly.
2. `AttentionToRockPattern::matchAndRewrite`: replace the
`&& !srcRegion.empty()` defensive check with an assertion.
`SingleBlockImplicitTerminator` on `migraphx.attention` guarantees
the body has a block, so the runtime check was always true.
3. `AttentionToRockPattern::matchAndRewrite`: extract a
`getCollapseToLastDimReassoc(rank)` helper alongside the existing
`getLeadingDimReassoc`, and use it for the LSE expand-shape (rank
-> {{0..rank-2}, {rank-1}}). Replaces an inline reassociation
builder at the LSE expand site.
No behavior change. All attention lit/E2E tests pass; full
check-rocmlir is green (1370/1370).
Co-authored-by: Cursor <cursoragent@cursor.com>
…fier and decompose

Verifier (`AttentionOp::verify`):
- Factor `qBatch + (splitKV if > 1) + trailing` shape construction into a `makeAttnShape` helper, used at 3 sites (result, LSE, body QK shape).
- Factor the `size != size || !std::equal(...)` shape mismatch check + diagnostic format into a `checkAttnShape` helper, used at 2 sites (result, LSE).
- Hoist `seqK = kShape[kRank - 1]` once at the top of `verify()` and reuse it for the splitKV-divisibility check, the QK shape construction, and the sliding-window max-seq-len bound (was duplicated 3x).

Net: ~30 lines shorter, consistent diagnostic format, and the splitKV-inflation rule now lives in one place that the next reader can scan in 4 lines instead of 4 inlined blocks.

Host decompose (`MIGraphXTransform.cpp`):
- Replace the manual permutation loops in `splitKVReshapeK` with the upstream `applyPermutation` helper from `mlir/Dialect/Utils/IndexingUtils.h`. Drops 2 explicit loops and an intermediate `kSplitStrides` ArrayRef.
- Add a small `createBroadcastIntScalar` helper that wraps the APInt + DenseElementsAttr + createBroadcastScalar boilerplate. Used at 2 sites (`createBroadcastColIndices`'s `seqKPerSplitConst` and `applySlidingWindowMask`'s `zeroI32` clamp).
- Tighten `Type si32 = getSi32Type(...)` to `IntegerType si32 = ...` across the file (`getSi32Type` already returns `IntegerType`); removes a `cast<IntegerType>` at the zero-clamp call site.

Drive-by: clang-format catches a stray blank line in `mlir-c/Dialect/MIGraphX.h` and a line-break in `RocMLIRPasses.td`'s `dependentDialects` list, both pre-existing in the branch.

No behavior change. All attention lit/E2E tests pass; full check-rocmlir is green (1370/1370).

Co-authored-by: Cursor <cursoragent@cursor.com>
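A sketch of `makeAttnShape` following the `qBatch + (splitKV if > 1) + trailing` rule quoted above; the exact signature is an assumption:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Sketch: Q's leading (batch/head) dims, then splitKV when > 1, then
// the site-specific trailing dims (result, LSE, or body QK shape).
static llvm::SmallVector<int64_t>
makeAttnShape(llvm::ArrayRef<int64_t> qLeadingDims, int64_t effectiveSplitKV,
              llvm::ArrayRef<int64_t> trailing) {
  llvm::SmallVector<int64_t> shape(qLeadingDims.begin(), qLeadingDims.end());
  if (effectiveSplitKV > 1)
    shape.push_back(effectiveSplitKV);
  shape.append(trailing.begin(), trailing.end());
  return shape;
}
```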
Five attention findings from a follow-up review:
1. (High) Reject dynamic dims. The verifier and downstream lowering do static shape arithmetic (% on seqK, leading-dim collapse, body QK shape construction). Allowing dynamic dims through migraphx.attention produced silently-broken IR. Add an up-front static-shape check on queries/keys/values/result/lse/currentSeqLen/prefixOffset and each preSoftmaxElemWiseInput, mirroring MultiBroadcastOp::verify's existing convention for this dialect.
2. (High) Require Q.elemType == K.elemType. The op definition let Q and K independently pick from `[F32, F16, BF16, I8]`. Mixed Q/K element types passed the verifier and fed into a `migraphx.dot` with mismatched operand types or routed via Q's element type alone in the GEMM-op selection. Reject mixed Q/K up front.
3. (Medium) Widen the host's second GEMM to softmaxType. The CPU reference's `migraphx.dot(softmax, V)` accumulated in V's element type while the GPU MFMA path keeps gemm1's accumulator at softmaxType (typically f32) and downcasts only at the end. For long sequences the CPU reference would diverge from the GPU. Convert V to softmaxType, run the dot in softmaxType, downcast the result to V's element type. The existing TODO at the second-GEMM site is now resolved and removed.
4. (Medium) Require lse.elemType == effective softmaxType. Both host and GPU compute LSE intermediates (reduce_sum, log, max, add) in the softmax type. Allowing an LSE result wider than softmaxType would silently round-trip through narrower intermediates. With softmaxType absent the effective type is V's element type; in both cases the LSE output must match.
5. (Low) Tighten kvcache-scale-lse threshold. After (3) the pre-existing `RMS_threshold = 0.03 / relDiff_threshold = 0.2` workaround on `mixr-attention-kvcache-scale-lse.mlir` is no longer needed; the test now passes at `RMS_threshold = 0.002 / relDiff_threshold = 0.005`, matching sibling attention tests.

Drive-bys:
- Polish the rank-2-rejection comment that read as if rank-2 still slips past (it does not).
- Fix the long-standing `stlide` -> `stride` typo in MIGraphXTypes.td.

Tests:
- 4 new invalid tests: dynamic Q, dynamic preSoftmax input, Q/K type mismatch, lse vs softmax mismatch.
- Update 2 ops.mlir tests and 1 attention-to-rock test to add the now-required `softmax_type = f32` for V.f16 + LSE.f32 cases.
- Replace `decompose_lse_type_convert` with `decompose_widened_second_gemm` exercising the widening path (V f16 + softmax_type=f32 + LSE f32 emits a convert before and after the second dot).

Verification: full check-rocmlir passes (1370/1370, 0 fail, 4 XFAIL).

Co-authored-by: Cursor <cursoragent@cursor.com>
Motivation
Currently MIGraphX passes a series of decomposed ops for attention. Then, inside TosaToRock, rocMLIR pattern-matches to discover the attention and determine which attention variant it is.
This pattern matching is becoming increasingly difficult on the rocMLIR side; it is much better for the MIGraphX graph compiler to do this.
Therefore this PR introduces a migraphx.attention op that takes a "features" attribute to describe which attention variant it is.
Technical Details
The migraphx.attention op has similar semantics to the rock.attention op.
Test Plan
Merge plan:
I plan to break this large PR down into several smaller ones, but I am keeping this large draft open for feedback on the overall structure.
Note to MIGraphX folks
Look at the files in mlir/test/fusion/pr-e2e/migraphx-attention/ to see examples of attention kernels, e.g. an attention kernel with sliding window + kvcache + causal masking, or splitKV = 2 with LSE and a "preSoftmaxBody".