[LinalgExt] Initial support for aten::flex_attention and rewrite to linalg_ext.attention #22441
base: main
Conversation
…ine_attention Signed-off-by: Keshav Vinayak Jha <[email protected]>
Instead of requiring IREE::LinalgExt::IndexOp inside the region body, pass iteration indices (b, h, m, n) directly as block arguments. This simplifies lowering passes. Changes:
- Update OnlineAttentionOp/AttentionOp verifiers to accept 1 or 5 block args.
- Modify applyPostQKMatmulElementwise to provide indices as block arguments.
- Update FlexAttentionOp lowering to declare 5 block arguments.
- Document region signature in LinalgExtOps.td.
The region now receives (score: f32, b: index, h: index, m: index, n: index) instead of requiring IndexOp operations to query indices.
Signed-off-by: Keshav Vinayak Jha <[email protected]>
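For illustration, here is a rough sketch of what an `iree_linalg_ext.attention` op with the new five-argument region signature could look like. The shapes, indexing maps, and exact assembly syntax below are my own assumptions for the example and are not taken from this PR:

```mlir
#mapQ = affine_map<(b, h, m, n, k1, k2) -> (b, h, m, k1)>
#mapK = affine_map<(b, h, m, n, k1, k2) -> (b, h, k2, k1)>
#mapV = affine_map<(b, h, m, n, k1, k2) -> (b, h, k2, n)>
#mapS = affine_map<(b, h, m, n, k1, k2) -> ()>
#mapO = affine_map<(b, h, m, n, k1, k2) -> (b, h, m, n)>

func.func @attention_with_indices(%q: tensor<2x8x128x64xf32>,
                                  %k: tensor<2x8x128x64xf32>,
                                  %v: tensor<2x8x128x64xf32>,
                                  %scale: f32,
                                  %init: tensor<2x8x128x64xf32>) -> tensor<2x8x128x64xf32> {
  %0 = iree_linalg_ext.attention
         {indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO]}
         ins(%q, %k, %v, %scale
             : tensor<2x8x128x64xf32>, tensor<2x8x128x64xf32>,
               tensor<2x8x128x64xf32>, f32)
         outs(%init : tensor<2x8x128x64xf32>) {
  // Score plus the (b, h, m, n) iteration indices, per the change above.
  ^bb0(%score: f32, %b: index, %h: index, %m: index, %n: index):
    iree_linalg_ext.yield %score : f32
  } -> tensor<2x8x128x64xf32>
  return %0 : tensor<2x8x128x64xf32>
}
```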
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
…ention region Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
…mension access Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
- PyTorch’s attention already supplies the correct scale for the base-e softmax. All the following files are changed to support exp computation instead of exp2.
  1. ConvertTorchUnstructuredToLinalgExt: the FlexAttentionOpConversion pattern passes use_exp2 = false, which can be used correctly in decomposition.
  2. AggregatedOpInterfaceImpl: accepts the use_exp2 flag as an attribute for decomposition and calls computeSubAndExp accordingly.
  3. LinalgExtOps.td: added getUseExp2AttrStr() to both online_attention and attention.
  4. ReshapeFusion.cpp: createCollapsedOP was recreating the attention op and stripping all attributes before StripCompilationPass; the change was necessary to support correct decomposition.
Signed-off-by: Keshav Vinayak Jha <[email protected]>
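For context on why the exponent base matters here (my own summary, not text from the PR): a base-2 decomposition only computes the same softmax if the scale has been pre-multiplied by $\log_2 e$, since

$$\exp(x) = 2^{\,x \log_2 e} = \operatorname{exp2}(x \log_2 e).$$

Because the flex_attention lowering already supplies the plain base-e scale, the decomposition has to use `exp` (use_exp2 = false) to leave the softmax unchanged.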
force-pushed from 28a772f to f860ccc
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
I can review after you fix all build failures. It does not look ready to me when CI is so red.
hanhanW left a comment:
High level comment:
- Please add period at the end of all your comments.
- The best code is self-documenting. Please delete redundant comments: https://google.github.io/styleguide/cppguide.html#Comments E.g.,
`// Yield modified score` is redundant to me.
I did not point out all the code that has this issue; please do a self-review before you request the next review. Thanks!
- General formatting.
- Removed hardcoded exp2 decomposition attribute.
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. ConvertTorchUnstructuredToLinalgExt.cpp: minor formatting, and removed a redundant comment.
2. unstructured_linalg_ext.mlir: modified the lit test with a smaller score function for simplicity; added a warning check in case the lse output is expected.
3. AggregatedOpInterfaceImpl.cpp: AttentionOp doesn't need the exp2 modification; removed redundant comments.
4. LinalgExtOps.td: AttentionOp doesn't need the getUseExp2AttrStr() attribute.
5. decompose_aggregate_op.mlir: added another online_attention op variant that passes use_exp2 = false in the decompositionConfig and CHECKs that math.exp2 ops are not created.
6. DecomposeAttention.cpp: extracts the useExp2 bool from the pass option and attaches a new decomposition config attr with the relevant value.
7. Passes.td: the iree-linalg-ext-decompose-attention pass now accepts useExp2 as a pass option (defaults to true to preserve legacy behavior).
8. ReshapeFusion.cpp: I had forced the collapsedOp to re-attach the decompositionConfig; this change is no longer needed.
Signed-off-by: Keshav Vinayak Jha <[email protected]>
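As a rough sketch of what the lit coverage described in items 5-7 might look like (the pass-option spelling and file layout are assumptions on my part, and the op under test is elided):

```mlir
// RUN: iree-opt --iree-linalg-ext-decompose-attention=useExp2=false %s | FileCheck %s

// An online_attention op carrying use_exp2 = false in its decomposition config
// would go here; after decomposition, only base-e exponentials should remain.
// CHECK-LABEL: @online_attention_base_e
// CHECK-NOT:     math.exp2
// CHECK:         math.exp
```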
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
… lse return Signed-off-by: Keshav Vinayak Jha <[email protected]>
…eadable Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
MaheshRavishankar left a comment:
I think this PR needs to be split up. There is a significant change to the ext.attention operation that this is trying to do with the change to the block arguments of the region. I think that is a worthwhile change, but we need to make sure it is done in a way that meshes well with all the transformations surrounding the operation. As it stands today, it isn't. Let's split this PR into:
- Changes to the ext.attention operation
- Lowering of flex attention operation to ext.attention operation.
static constexpr int kNumModificationIndices = 4;

// Makes it convenient to pass around commonly used types.
struct TypeInfo {
I don't think we want to do this. This is an unnecessary indirection, and it is potentially creating new types in the global context without needing to.
// Check if the element type is a float.
Type elementType = queryType.getOptionalDtype();
auto floatType = dyn_cast<FloatType>(elementType);
Just a note. I don't know the torch operator here, but if this operation is expected to have only a float type, you can add that check to the verifier of the torch op and avoid checking it here. It just makes the code less verbose. If needed, you can assert that it is a float type here.
Makes sense, I'll move it to the ::verify of the torch op itself. I think a lot of checks here can be moved to the verifier.
Value outputInit = arith::getIdentityValue(arith::AtomicRMWKind::addf,
                                           floatType, rewriter, loc,
                                           /*useOnlyFiniteValue=*/true);
Value outputTensor = tensor::SplatOp::create(rewriter, loc, outputInit,
I am not sure we create splat operations. If this is for the output tensor, maybe create a linalg.fill operation? That's what the compiler handles best. In any case, as long as the splat gets converted to a fill later in compilation, it's a moot point.
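A minimal sketch of the fill-based initialization being suggested, assuming a zero identity value and a dynamically shaped output (names and shapes are illustrative):

```mlir
func.func @zero_init(%d0: index, %d1: index) -> tensor<?x?xf32> {
  %zero = arith.constant 0.0 : f32
  // tensor.empty + linalg.fill is the pattern the compiler is tuned for,
  // as opposed to materializing a tensor.splat.
  %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %filled = linalg.fill ins(%zero : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %filled : tensor<?x?xf32>
}
```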
      rewriter, loc, torchType.toBuiltinTensor(), torchTensor);
}

void setTypeInfo(PatternRewriter &rewriter, FloatType floatType) const {
Nit: I am usually not a proponent of member functions. This seems to be just a helper method. You can move this and others out of the class as utility functions.
computeDynamicSizes(rewriter, loc, maskShape, maskDynSizes, builtinQuery,
                    builtinKey);

Value maskTensor = tensor::EmptyOp::create(rewriter, loc, maskShape,
Why is the mask float type? That's a lot of storage for just a boolean.
Region:
  The region body receives the following block arguments:
  - score: the computed score value from Q @ K.T (element type of output)
Nit: Can you make this such that the index operands come first and then the score....
Not Nit: The way the attention operation is set up, it can increase the dimensionality of the operation. So there isn't necessarily one batch dimension; there could be multiple batch dimensions. Same for head/m/n, etc. That probably needs to be accounted for in the change. Just having a single batch dimension is not going to work.
@MaheshRavishankar Thanks for the review, I'll incorporate some of the suggested changes.
Yeah, first you should make a change to the …
Previously, that's how the implementation looked; I added support for … In any case, it's easier to support index ops inside the attention region (with small modifications to the …). Regardless, I'll have to add tiling support for these index ops in the attention region, which will be a follow-up PR.
Why would you need to add support for index ops while tiling? You shouldn't need to.
I had discussed this with @Groverkss earlier; the goal was to add support for iree_linalg_ext.index in the attention op region to enable index-based logic (e.g., masking for RPE). At the time, he pointed out that using the … While most tiled ops don't require index in the region, this op does. It relies on index-dependent semantics like relative position encoding, which can't be expressed correctly without index after tiling.
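To make the index-dependence concrete, here is a hypothetical score-mod region fragment (not taken from this PR) that applies a causal mask, something that is only expressible with access to the iteration indices:

```mlir
// Fragment of a hypothetical attention region using the index block arguments:
// keep the score when m >= n, otherwise replace it with -inf before the softmax.
^bb0(%score: f32, %b: index, %h: index, %m: index, %n: index):
  %neg_inf = arith.constant 0xFF800000 : f32
  %keep = arith.cmpi sge, %m, %n : index
  %masked = arith.select %keep, %score, %neg_inf : f32
  iree_linalg_ext.yield %masked : f32
```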

This PR adds comprehensive support for converting PyTorch's `torch::hop_flex_attention` operation to IREE's `linalg_ext.attention` operation. The implementation includes proper handling of score modification functions, mask functions, and correct softmax computation using base-e exponential (exp) instead of base-2 exponential (exp2).

Changes:
- Added a `FlexAttentionOpConversion` pattern to convert `torch.hop_flex_attention` to `linalg_ext.attention`.
- Changed `AttentionOp` to pass iteration indices as block arguments to the region.

Fixes:
- The softmax decomposition uses `exp` instead of `exp2`.

Testing:
- Exported the model with `aot.export` and compared against eager mode; I noticed no accuracy losses (on CPU).
- Compiled with: `iree-compile --iree-stream-partitioning-favor=min-peak-memory --iree-hal-target-device=local --iree-hal-local-target-device-backends=llvm-cpu --iree-llvmcpu-target-triple=x86_64-pc-linux-elf --iree-llvmcpu-debug-symbols=false ../torch-mlir/exported_ir.mlir -o ./flex_attention_cpu.vmfb`