[Dispatch Creation] Don't fuse uses from above #22708
```diff
@@ -127,6 +127,19 @@ class FusionGroup {
   FailureOr<AffineMap> getRootParallelLoopToOpMap(Operation *op) const;
 
   bool isFusable(Operation *op) const {
+    // We only handle fusion across an operation's operands. Don't fuse if the
+    // operation is using values in the fusion group in its body.
+    bool hasUseFromAbove = false;
+    mlir::visitUsedValuesDefinedAbove(
+        op->getRegions(), [&](OpOperand *operand) {
+          if (loopMaps.contains(operand->get().getDefiningOp())) {
+            hasUseFromAbove = true;
+          }
+        });
+    if (hasUseFromAbove) {
+      return false;
+    }
+
     FailureOr<AffineMap> maybeMap = getRootParallelLoopToOpMap(op);
     if (failed(maybeMap)) {
       return false;
```

Review discussion on the `loopMaps.contains(...)` check:

Collaborator: What is `loopMaps`?

Contributor (author): It's a map that contains all of the operations currently in the "fusion group".
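To make that answer concrete, here is a minimal sketch of what a fusion-group membership map like `loopMaps` could look like. Only the field name, the `contains` query, and the idea of mapping group members to an `AffineMap` come from the diff and the discussion; the container choice and the `add` helper are assumptions for illustration, not the actual IREE implementation.

```cpp
// Hypothetical sketch of the fusion-group bookkeeping discussed above; names
// other than `loopMaps` are assumptions, not taken from the IREE sources.
#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Operation.h"

class FusionGroupSketch {
public:
  // Record an operation as a member of the fusion group, together with the
  // map relating the root op's parallel loops to this op's iteration space.
  void add(mlir::Operation *op, mlir::AffineMap rootToOpMap) {
    loopMaps[op] = rootToOpMap;
  }

  // Membership query; this is what `loopMaps.contains(...)` answers in the
  // use-from-above check above.
  bool contains(mlir::Operation *op) const { return loopMaps.contains(op); }

private:
  // All operations currently in the fusion group, keyed by operation.
  llvm::DenseMap<mlir::Operation *, mlir::AffineMap> loopMaps;
};
```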
The PR also adds a lit test covering this case:

```diff
@@ -1837,3 +1837,41 @@ util.func public @no_fusion_across_blocks(%arg0: tensor<3x2xf32>) -> tensor<f32>
 // CHECK-SAME: ins(%[[DISPATCH0]], %[[FILL]]
 // CHECK: flow.return %[[DIV]]
 // CHECK: util.return %[[DISPATCH1]]
+
+// -----
+
+util.func public @no_fusion_use_from_above(%arg0 : tensor<?x?xf32>,
+    %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %empty2 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %consumer = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
+      ins(%matmul : tensor<?x?xf32>)
+      outs(%empty2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %c0_idx = arith.constant 0 : index
+    %c1_idx = arith.constant 1 : index
+    %extracted = tensor.extract %matmul[%c0_idx, %c1_idx] : tensor<?x?xf32>
+    %sum = arith.addf %in, %extracted : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?xf32>
+  util.return %consumer : tensor<?x?xf32>
+}
+// CHECK-LABEL: util.func public @no_fusion_use_from_above(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK: flow.return %[[MATMUL]]
+// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
+// CHECK: %[[CONSUMER:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[DISPATCH0]]
+// CHECK: tensor.extract %[[DISPATCH0]]
+// CHECK: flow.return %[[CONSUMER]]
+// CHECK: util.return %[[DISPATCH1]]
```

Review discussion on the `tensor.extract %matmul` inside the body of the consuming `linalg.generic`:

Collaborator: Interesting example. First, I don't know why we are lowering it this way; this seems like an awfully strange lowering. Are we sure we know what's happening with the input lowering from torch-mlir? I don't think we want to end up with this lowering. It so happens that the SSA violation is being used to prevent this fusion, but we should prevent this fusion "structurally" as well, i.e. when we decide whether this is a fusable consumer, we should mark it as not fusable due to the use from above.

Collaborator: EDIT: I see that you have checked exactly that below. Nice!

Contributor (author): This test might be a bit over-simplified, which makes it look stranger than it is. I've lost the exact IR from CI that was causing this issue, but I should be able to re-create it on #22341.
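The "structural" check the reviewer is describing amounts to rejecting a consumer whose region body reads the producer's result directly, rather than relying on the fused IR becoming invalid. A minimal sketch of such a check is below; the standalone helper and its name are assumptions for illustration, but `visitUsedValuesDefinedAbove` is the same MLIR utility the PR uses.

```cpp
// Hypothetical helper sketching the structural check discussed above: a
// consumer is not a fusable consumer of `producer` if any value defined by
// `producer` is used from within the consumer's region(s), e.g. via a
// tensor.extract in the body of a linalg.generic.
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/RegionUtils.h"

static bool consumerUsesProducerFromAbove(mlir::Operation *producer,
                                          mlir::Operation *consumer) {
  bool usesFromAbove = false;
  mlir::visitUsedValuesDefinedAbove(
      consumer->getRegions(), [&](mlir::OpOperand *operand) {
        // getDefiningOp() returns null for block arguments, so the comparison
        // below handles those correctly as well.
        if (operand->get().getDefiningOp() == producer)
          usesFromAbove = true;
      });
  return usesFromAbove;
}
```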
One more point was confirmed during review: "So just to confirm, `isAncestor` includes the operation itself as well?" The answer: "Yes!"
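For reference on that confirmation: in MLIR, `Operation::isAncestor` does treat an operation as its own ancestor, and `Operation::isProperAncestor` is the variant that excludes it. A small sketch of the distinction (illustrative only, not code from the PR):

```cpp
#include "mlir/IR/Operation.h"

// Operation::isAncestor counts an operation as its own ancestor, while
// Operation::isProperAncestor does not.
bool isSelfAncestor(mlir::Operation *op) {
  bool ancestorOfSelf = op->isAncestor(op);              // true
  bool properAncestorOfSelf = op->isProperAncestor(op);  // false
  return ancestorOfSelf && !properAncestorOfSelf;        // always true
}
```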