
Conversation

@keshavvinayak01 (Contributor) commented Oct 28, 2025

This PR adds comprehensive support for converting PyTorch's torch::hop_flex_attention operation to IREE's linalg_ext.attention operation. The implementation includes proper handling of score modification functions, mask functions, and correct softmax computation using base-e exponential (exp) instead of base-2 exponential (exp2).

Changes

  • Added FlexAttentionOpConversion pattern to convert torch.hop_flex_attention to linalg_ext.attention
  • Modified AttentionOp to pass iteration indices as block arguments to the region
  • Enables score modification functions to access batch, head, query sequence, and key sequence indices
  • Added lit tests for both LSE (log-sum-exp) and non-LSE cases
  • Simplified region handling by using block arguments instead of IndexOp operations

Fixes:

  • PyTorch's flex attention already supplies the correct scale for a base-e softmax. This commit fixes the computation to use exp instead of exp2 (see the sketch after this list).
  • The use_exp2 flag is largely unused in dialect conversions and passes; I presume it is used as a KernelOption. The changes here do not modify the default behavior.
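
For reference, a minimal sketch (plain Python, illustrative value) of the base change involved: exp(x) equals 2**(x * log2(e)), so an exp2-based decomposition is only correct when log2(e) is folded into the scale; since flex attention supplies a base-e scale, the decomposition must use exp.

import math

score = 1.7  # an arbitrary pre-softmax score value
assert math.isclose(math.exp(score), 2.0 ** (score * math.log2(math.e)))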

Testing:

  • I ran the entire flex_attention_hop implementation with randomized input tensors through aot.export (also see torch-mlir) and compared against eager mode; I observed no accuracy loss on CPU. A sketch of the check appears after this list.
  • Command: iree-compile --iree-stream-partitioning-favor=min-peak-memory --iree-hal-target-device=local --iree-hal-local-target-device-backends=llvm-cpu --iree-llvmcpu-target-triple=x86_64-pc-linux-elf --iree-llvmcpu-debug-symbols=false ../torch-mlir/exported_ir.mlir -o ./flex_attention_cpu.vmfb
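
For context, a rough sketch of the eager-mode comparison described above. The shapes, the score_mod, and the tolerances are illustrative, not the ones used in the actual run; it assumes torch >= 2.5 with torch.nn.attention.flex_attention and only exercises the eager path.

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

def score_mod(score, b, h, q_idx, kv_idx):
    # Illustrative score modification: a small relative-position bias.
    return score + (q_idx - kv_idx) * 0.01

# Fix scale=1.0 so the reference below does not depend on where the
# 1/sqrt(D) scaling is applied relative to score_mod.
out = flex_attention(q, k, v, score_mod=score_mod, scale=1.0)

# Base-e reference: the softmax uses exp, not exp2.
scores = q @ k.transpose(-2, -1)
q_idx = torch.arange(S).view(1, 1, S, 1)
kv_idx = torch.arange(S).view(1, 1, 1, S)
ref = torch.softmax(score_mod(scores, None, None, q_idx, kv_idx), dim=-1) @ v
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)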

@keshavvinayak01 changed the title from "Initial support for aten::flex_attention and rewrite to linalgext.onl…" to "[WIP]Initial support for aten::flex_attention and rewrite to linalgext.online_attention" on Oct 28, 2025
@keshavvinayak01 changed the title from "[WIP]Initial support for aten::flex_attention and rewrite to linalgext.online_attention" to "[LinalgExt] Initial support for aten::flex_attention and rewrite to linalgext.online_attention" on Nov 8, 2025
Instead of requiring IREE::LinalgExt::IndexOp inside the region body,
pass iteration indices (b, h, m, n) directly as block arguments. This
simplifies lowering passes.

Changes:
- Update OnlineAttentionOp/AttentionOp verifiers to accept 1 or 5 block args
- Modify applyPostQKMatmulElementwise to provide indices as block arguments
- Update FlexAttentionOp lowering to declare 5 block arguments
- Document region signature in LinalgExtOps.td

The region now receives: (score: f32, b: index, h: index, m: index, n: index)
instead of requiring IndexOp operations to query indices.
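
For orientation, a small sketch of the PyTorch-side callbacks whose index arguments these block arguments mirror (function names and constants are illustrative, not from this PR):

def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod-style predicate over the iteration indices.
    return q_idx >= kv_idx

def alibi_like_bias(score, b, h, q_idx, kv_idx):
    # score_mod-style modification; (score, b, h, m, n) carries exactly
    # this information into the attention region.
    return score - abs(q_idx - kv_idx) * 0.05 * (h + 1)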

Signed-off-by: Keshav Vinayak Jha <[email protected]>
- PyTorch's flex attention already supplies the correct scale for the base-e softmax. The following files are changed to support exp computation instead of exp2:
1. ConvertTorchUnstructuredToLinalgExt: the FlexAttentionOpConversion pattern passes use_exp2 = false, which is then used correctly during decomposition.
2. AggregatedOpInterfaceImpl: accepts the use_exp2 flag as an attribute for decomposition and calls computeSubAndExp accordingly.
3. LinalgExtOps.td: added getUseExp2AttrStr() to both online_attention and attention.
4. ReshapeFusion.cpp: createCollapsedOP was recreating the attention op and stripping all attributes before StripCompilationPass; this change was necessary to support correct decomposition.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
@keshavvinayak01 force-pushed the users/keshavvinayak01/flex-attention branch from 28a772f to f860ccc on November 12, 2025 10:05
@keshavvinayak01 marked this pull request as ready for review on November 12, 2025 10:05
@keshavvinayak01 changed the title from "[LinalgExt] Initial support for aten::flex_attention and rewrite to linalgext.online_attention" to "[LinalgExt] Initial support for aten::flex_attention and rewrite to linalg_ext.attention" on Nov 12, 2025
Signed-off-by: Keshav Vinayak Jha <[email protected]>
@hanhanW (Contributor) commented Nov 12, 2025

I can review after you fix all the build failures. It doesn't look ready to me while CI is this red.

@hanhanW (Contributor) left a comment:

High-level comment:

I did not point out all the code that has the issue; please do a self-review before requesting the next review. Thanks!

- General formatting
- Removed hardcoded exp2 decomposition attribute

Signed-off-by: Keshav Vinayak Jha <[email protected]>
@hanhanW (Contributor) commented Nov 14, 2025

Please click the re-request review button or ping me on Discord when it is ready for review. Thanks! :)


Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. ConvertTorchUnstructuredToLinalgExt.cpp: minor formatting fixes and removed a redundant comment.
2. unstructured_linalg_ext.mlir: modified the lit test to use a smaller score function for simplicity; added a warning check in case an LSE output is expected.
3. AggregatedOpInterfaceImpl.cpp: AttentionOp doesn't need the exp2 modification; removed redundant comments.
4. LinalgExtOps.td: AttentionOp doesn't need the getUseExp2AttrStr() attribute.
5. decompose_aggregate_op.mlir: added another online_attention op variant that passes use_exp2=False in the decompositionConfig and CHECKs that no math.exp2 ops are created.
6. DecomposeAttention.cpp: extracts the useExp2 bool from the pass option and attaches a new decomposition config attribute with the relevant value.
7. Passes.td: the iree-linalg-ext-decompose-attention pass now accepts useExp2 as a pass option (defaults to true to preserve the existing behavior).
8. ReshapeFusion.cpp: I had forced the collapsed op to re-attach the decompositionConfig; that change is no longer needed.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
@MaheshRavishankar (Collaborator) left a comment:

I think this PR needs to be split up. There is a significant change to the ext.attention operation that this is trying to make, with the change to the block arguments of the region. I think that is a worthwhile change, but we need to make sure it is done in a way that meshes well with all the transformations surrounding the operation, and as written today that isn't the case. Let's split this PR into:

  1. Changes to the ext.attention operation
  2. Lowering of the flex attention operation to the ext.attention operation.

static constexpr int kNumModificationIndices = 4;

// Makes it convenient to pass around commonly used types.
struct TypeInfo {
@MaheshRavishankar (Collaborator) commented on this hunk:

I don't think we want to do this. This is an unnecessary indirection, and it is potentially creating new types in the global context without needing to.


// Check if the element type is a float.
Type elementType = queryType.getOptionalDtype();
auto floatType = dyn_cast<FloatType>(elementType);
@MaheshRavishankar (Collaborator) commented on this hunk:

Just a note: I don't know the torch operator here, but if this operation is expected to have only float element types, you can add that check to the verifier of the torch op and avoid checking it here. It just makes the code less verbose. If needed, you can assert that it is a float type here.

@keshavvinayak01 (Contributor, Author) replied Nov 21, 2025:

Makes sense; I'll move it to the ::verify of the torch op itself. I think a lot of checks here can be moved to the verifier.

Value outputInit = arith::getIdentityValue(arith::AtomicRMWKind::addf,
                                           floatType, rewriter, loc,
                                           /*useOnlyFiniteValue=*/true);
Value outputTensor = tensor::SplatOp::create(rewriter, loc, outputInit,
@MaheshRavishankar (Collaborator) commented on this hunk:

I am not sure we create splat operations. If this is for the output tensor, maybe create a linalg.fill operation? That's what the compiler handles best. In any case, as long as the splat gets converted to a fill later in compilation, it's a moot point.

rewriter, loc, torchType.toBuiltinTensor(), torchTensor);
}

void setTypeInfo(PatternRewriter &rewriter, FloatType floatType) const {
@MaheshRavishankar (Collaborator) commented on this hunk:

Nit: I am usually not a proponent of member functions. This seems to be just a helper method. You can move this and others out of the class as utility functions.

computeDynamicSizes(rewriter, loc, maskShape, maskDynSizes, builtinQuery,
                    builtinKey);

Value maskTensor = tensor::EmptyOp::create(rewriter, loc, maskShape,
@MaheshRavishankar (Collaborator) commented on this hunk:

Why is the mask a float type? That's a lot of storage for just a boolean.


Region:
The region body receives the following block arguments:
- score: the computed score value from Q @ K.T (element type of output)
@MaheshRavishankar (Collaborator) commented on this hunk:

Nit: Can you make this such that the index operands come first, followed by the score?

Not a nit: The way the attention operation is set up, it can increase the dimensionality of the operation, so there isn't necessarily one batch dimension; there could be multiple batch dimensions. The same holds for the head/m/n dimensions. That probably needs to be accounted for in this change; just having a single batch dimension is not going to work.

@keshavvinayak01 (Contributor, Author) commented:

@MaheshRavishankar Thanks for the review, I'll incorporate some of the suggested changes.
As for the PR bifurcation, it's a bit difficult to isolate the lowering from the required changes. The lowering relies on block arguments being supported, which is a change to the AttentionOp interface itself.
Could you be more specific about how you see this split being done?

@MaheshRavishankar (Collaborator) replied:

Yeah, first you should make a change to the AttentionOp itself that adds the block-argument support. Then you can make the changes to the lowering. The part adding block arguments here needs to be fleshed out a bit more. Do we really need the change to get block arguments in? One option is to not change the block arguments but to use linalg_ext.index (similar to linalg.index). That might be easier to land without too many breakages.

@keshavvinayak01 (Contributor, Author) replied:

Previously, that's what the implementation looked like: I added support for linalg_ext.index ops inside the attention region. That worked, but it bloated the IR, so I thought this was a simpler solution.

In any case, it's easier to support index ops inside the attention region (with small modifications to the ::verify method), so I'll switch back to that approach, since it's less likely to break things in the future.

Regardless, I'll have to add tiling support for these index ops in the attention region, which will be a follow-up PR.

@MaheshRavishankar (Collaborator) commented:

Why would you need to add support for index ops while tiling? You shouldn't need to.

@keshavvinayak01 (Contributor, Author) commented Nov 26, 2025

I had discussed this with @Groverkss earlier; the goal was to add support for iree_linalg_ext.index in the attention op region to enable index-based logic (e.g., masking for RPE). At the time, he pointed out that using the mask operand in AttentionOp was a hack.

While most tiled ops don't require index in the region, this op does. It relies on index-dependent semantics like relative position encoding, which can't be expressed correctly without index after tiling.
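
For illustration, a hedged sketch of the kind of index-dependent score_mod meant here (names and sizes are hypothetical, not from this PR). The bias lookup depends on the global (q_idx, kv_idx) pair, which is why tiles of the iteration space still need access to their global iteration indices.

import torch

MAX_SEQ = 512  # assumed maximum sequence length
rel_bias = torch.zeros(2 * MAX_SEQ - 1)  # learnable in practice

def rpe_score_mod(score, b, h, q_idx, kv_idx):
    # Relative position encoding: the bias is indexed by the global
    # distance q_idx - kv_idx, so a tiled op must still see global indices
    # (e.g. via iree_linalg_ext.index inside the region).
    return score + rel_bias[q_idx - kv_idx + MAX_SEQ - 1]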
