[RemoveLayoutConversions] Fix reduce failed infer type error #377
zhanglx13 merged 3 commits into ROCm:triton-mlir
Conversation
#ifndef USE_ROCM
  auto isExtOp = [](Operation *op) {
    return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op);
  };
#else
  auto isExtOp = [](Operation *op) {
    return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
               triton::BroadcastOp, triton::ExpandDimsOp>(op);
  };
#endif
This part changed multiple times between our current and previous IFU.
I wrapped it in USE_ROCM to highlight it in case this PR is merged before the ongoing IFU.
  auto newType =
      RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(),
-                           layout[extOp->getResult(0)]);
+                           *srcEncoding);
This part changed upstream only once (and looks exactly the same), so it should not cause any problems when merging or rebasing on top of upstream changes.
08903f2 to 086fbcb
For the record, a description of the problem:
Consider this example as input:
Before this change, optimizations moved the layout conversion before
With this change, the layout conversion is moved only before
  // CHECK-NOT: triton_gpu.convert_layout
  // CHECK: tt.reduce
  // CHECK-SAME: axis = 1
- // CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+ // CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>>
  // CHECK: triton_gpu.convert_layout
  // CHECK: tt.expand_dims
This change is also part of the upstream changes.
With this PR I can confirm that the argmax unit tests and the previous reproducer now pass. However, this change seems to have caused a failure for us in another unit test. @alefimov-amd, let me know if you want me to put together a reproducer for this new issue.
@alefimov-amd Reproducer for the new issue:
Error log:
This PR fixes the layout propagation algorithm in the RemoveLayoutConversions pass. In some cases during the rewriteSlice process, a reduce operation with multiple outputs has only one of its output layouts rewritten, which breaks the assumption that both outputs have the same layout. This change is a minimal part of triton-lang#2331, plus a small lit test for regression testing.
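The invariant described above can be illustrated with a small toy model. This is only a sketch and not the actual patch: `ToyOp`, `Layout`, `ValueId`, `rewriteOneResult`, `rewriteAllResults`, and `resultsAgree` are hypothetical names invented for illustration and are not Triton or MLIR APIs. It assumes a reduce with two results (for example a value and an index, as in argmax) and shows how updating the layout of only one result leaves the results disagreeing.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy stand-ins for illustration only; these are NOT Triton or MLIR types.
using Layout = std::string;
using ValueId = int;

struct ToyOp {
  std::string name;             // e.g. "tt.reduce"
  std::vector<ValueId> results; // a two-output reduce has two result ids
};

// Problematic variant: records the new layout only for the visited result.
void rewriteOneResult(std::map<ValueId, Layout> &layouts, ValueId visited,
                      const Layout &newLayout) {
  layouts[visited] = newLayout;
}

// Consistent variant: every result of the op receives the new layout.
void rewriteAllResults(std::map<ValueId, Layout> &layouts, const ToyOp &op,
                       const Layout &newLayout) {
  for (ValueId r : op.results)
    layouts[r] = newLayout;
}

// Returns true when all results of `op` are mapped to the same layout.
bool resultsAgree(const std::map<ValueId, Layout> &layouts, const ToyOp &op) {
  for (ValueId r : op.results)
    if (layouts.at(r) != layouts.at(op.results.front()))
      return false;
  return true;
}

// Walks through both variants on a two-result reduce.
void demo() {
  ToyOp reduce{"tt.reduce", {0, 1}};
  std::map<ValueId, Layout> layouts{{0, "blocked"}, {1, "blocked"}};

  rewriteOneResult(layouts, 0, "slice");
  assert(!resultsAgree(layouts, reduce)); // results now disagree

  rewriteAllResults(layouts, reduce, "slice");
  assert(resultsAgree(layouts, reduce)); // invariant restored
}
```

Calling `demo()` from a `main` function exercises both variants; the first assertion documents the inconsistent state and the second the restored invariant.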
4728d39 to 9da8508
jataylo left a comment
LGTM. Our core UTs and noted failing cases are passing after this change.
* [RemoveLayoutConversions] Fix reduce failed infer type error
* fix combine test
* Fix issue with incorrect inference layout of make_range output result
Fixes #364