[RemoveLayoutConversions] Fix reduce failed infer type error #377
Pass changes in `hoistConvertOnTopOfExt`:

```diff
@@ -1013,9 +1013,16 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) {
   if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
     return;
 
+#ifndef USE_ROCM
   auto isExtOp = [](Operation *op) {
     return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op);
   };
+#else
+  auto isExtOp = [](Operation *op) {
+    return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
+               triton::BroadcastOp, triton::ExpandDimsOp>(op);
+  };
+#endif
   // 1. Take a backward slice of all the tensor dependencies.
   SetVector<Value> slice;
   DenseMap<Value, Attribute> layout;
```
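On ROCm the hoisting now also looks through `tt.broadcast` and `tt.expand_dims`, not just the `arith` extension ops. For orientation, here is a minimal sketch of the kind of IR this targets; the function name and values are made up, and the encodings are copied from the new test added below:

```mlir
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @hoist_sketch() {
    // A value produced in the slice layout of #blocked, e.g. by a reduction.
    %r = arith.constant dense<0> : tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    // tt.expand_dims consumes the slice layout and produces #blocked ...
    %e = tt.expand_dims %r {axis = 1 : i32} : (tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<256x1xi64, #blocked>
    // ... followed by the convert_layout that the pass wants to hoist above it.
    %c = triton_gpu.convert_layout %e : (tensor<256x1xi64, #blocked>) -> tensor<256x1xi64, #blocked1>
    tt.return
  }
}
```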
```diff
@@ -1034,8 +1041,11 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) {
     if (isExtOp(op)) {
       SetVector<Value> tempSlice;
       DenseMap<Value, Attribute> tempLayout;
+      std::optional<Attribute> srcEncoding = inferSrcEncoding(op, layout[v]);
+      if (!srcEncoding)
+        return;
       LogicalResult result = getRematerializableSlice(
-          op->getOperand(0), layout[v], tempSlice, tempLayout);
+          op->getOperand(0), *srcEncoding, tempSlice, tempLayout);
       // If we can rematerialize the rest of the ext slice we can ignore this
       // ext as it won't need a convert.
       if (result.succeeded()) {
```
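Why the operand encoding has to be inferred rather than copied: for the `arith.ext*` ops the operand and result share one encoding, so `layout[v]` happened to be correct, but for `tt.expand_dims` (and `tt.broadcast`) the operand lives in a different encoding than the result, which is what produced the failed infer type error this PR fixes. A hedged fragment, reusing `#blocked` from the sketch above (the value names are placeholders):

```mlir
// Element-wise extension: operand and result use the same encoding, so the
// result layout can double as the operand layout.
%f = arith.extf %h : tensor<256x256xf16, #blocked> to tensor<256x256xf32, #blocked>

// tt.expand_dims: the operand must use the slice encoding of the result
// layout. inferSrcEncoding recovers that encoding; if it cannot, the new
// early return skips the hoist instead of building an invalid type.
%e = tt.expand_dims %r {axis = 1 : i32} : (tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<256x1xi64, #blocked>
```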
```diff
@@ -1053,12 +1063,14 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) {
 
   if (extOp == nullptr)
     return;
+  std::optional<Attribute> srcEncoding =
+      inferSrcEncoding(extOp, layout[extOp->getResult(0)]);
   // Move the convert before the ext op and rewrite the slice.
   OpBuilder builder(extOp);
   auto tensorType = extOp->getOperand(0).getType().cast<RankedTensorType>();
   auto newType =
       RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(),
-                            layout[extOp->getResult(0)]);
+                            *srcEncoding);
   auto newConvertOp = builder.create<ConvertLayoutOp>(
       convertOp.getLoc(), newType, extOp->getOperand(0));
   IRMapping mapping;
```

Review comment: This part changed in upstream only once (and looks exactly the same), so it should not introduce any problems on a merge or rebase on top of upstream changes.
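Putting the two pieces together, for the sketch shown earlier the pass should now hoist the convert over the `tt.expand_dims` and give it the slice encoding inferred from the target layout `#blocked1`. This is a hand-written approximation of the expected rewrite, not actual pass output:

```mlir
// The hoisted convert keeps the operand's shape and element type but uses the
// inferred slice encoding of the target layout #blocked1 ...
%c2 = triton_gpu.convert_layout %r : (tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// ... and the expand_dims is recreated on top of it, directly in #blocked1.
%e2 = tt.expand_dims %c2 {axis = 1 : i32} : (tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi64, #blocked1>
```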
Update to the existing `reduce_cvt2` lit test:

```diff
@@ -1172,10 +1172,12 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
 
 // CHECK-LABEL: reduce_cvt2
 // Match the reduction
 // CHECK-NOT: triton_gpu.convert_layout
 // CHECK: tt.reduce
 // CHECK-SAME: axis = 1
-// CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>>
 // CHECK: triton_gpu.convert_layout
 // CHECK: tt.expand_dims
 // CHECK-NOT: triton_gpu.convert_layout
 // CHECK: tt.return
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
```

Review comment on lines +1175 to +1180: This change is also part of the upstream changes.
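For reference, `{{...}}` in a FileCheck line is an inline regular expression, so `#{{.*}}` matches whichever alias the printer picks for the encoding instead of hard-coding `#blocked`. A small hedged illustration (the IR mirrors the shape of the test, it is not copied from pass output):

```mlir
// This reduce is accepted by the relaxed CHECK line regardless of what the
// blocked encoding ends up being named:
// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>>
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
^bb0(%a: f32, %b: f32):
  tt.reduce.return %b : f32
}) : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
```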
New lit test file (59 added lines):

```mlir
// RUN: triton-opt %s -split-input-file --tritongpu-remove-layout-conversions -canonicalize | FileCheck %s
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  // CHECK-LABEL: remove_layout_multiple_outputs
  tt.func public @remove_layout_multiple_outputs(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %second_reduce_input = arith.constant dense<9223372036854775807> : tensor<256x256xi64, #blocked>
    %load_mask = arith.constant dense<1>: tensor<1x256xi1, #blocked>
    %store_mask = arith.constant dense<1>: tensor<256x1xi1, #blocked1>
    %default_load_val = arith.constant dense<0.000000e+00> : tensor<256x256xf16, #blocked>
    %70 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<256x256x!tt.ptr<f16, 1>, #blocked>
    %76 = tt.broadcast %load_mask : (tensor<1x256xi1, #blocked>) -> tensor<256x256xi1, #blocked>
    %87 = tt.load %70, %76, %default_load_val {cache = 1 : i32, evict = 2 : i32, isVolatile = false} : tensor<256x256xf16, #blocked>
    %88 = triton_gpu.convert_layout %87 : (tensor<256x256xf16, #blocked>) -> tensor<256x256xf16, #blocked1>
    %89 = arith.extf %87 : tensor<256x256xf16, #blocked> to tensor<256x256xf32, #blocked>
    %108:2 = "tt.reduce"(%89, %second_reduce_input) <{axis = 1 : i32}> ({
    ^bb0(%arg5: f32, %arg6: i64, %arg7: f32, %arg8: i64):
      tt.reduce.return %arg7, %arg6 : f32, i64
    }) : (tensor<256x256xf32, #blocked>, tensor<256x256xi64, #blocked>) -> (tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>)
    %111 = tt.splat %arg1 : (!tt.ptr<i64, 1>) -> tensor<256x1x!tt.ptr<i64, 1>, #blocked1>
    %109 = tt.expand_dims %108#1 {axis = 1 : i32} : (tensor<256xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<256x1xi64, #blocked>
    %110 = triton_gpu.convert_layout %109 : (tensor<256x1xi64, #blocked>) -> tensor<256x1xi64, #blocked1>
    tt.store %111, %110, %store_mask {cache = 1 : i32, evict = 1 : i32} : tensor<256x1xi64, #blocked1>
    tt.return
  }
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  // CHECK-LABEL: make_range_layout_inference
  tt.func public @make_range_layout_inference(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst_5 = arith.constant dense<1> : tensor<128x1xi1, #blocked>
    %cst_7 = arith.constant dense<1> : tensor<128x4xi1, #blocked>
    %cst_8 = arith.constant dense<128> : tensor<128x1xi32, #blocked>
    %cst_9 = arith.constant dense<1.1> : tensor<128x1xf32, #blocked>
    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %4 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
    %5 = triton_gpu.convert_layout %4 : (tensor<128x1xi32, #blocked2>) -> tensor<128x1xi32, #blocked>
    %18 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<128x4x!tt.ptr<f32, 1>, #blocked>
    %24 = triton_gpu.convert_layout %18 : (tensor<128x4x!tt.ptr<f32, 1>, #blocked>) -> tensor<128x4x!tt.ptr<f32, 1>, #blocked2>
    %25 = triton_gpu.convert_layout %cst_7 : (tensor<128x4xi1, #blocked>) -> tensor<128x4xi1, #blocked2>
    %27 = tt.load %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x4xf32, #blocked2>
    %28 = triton_gpu.convert_layout %27 : (tensor<128x4xf32, #blocked2>) -> tensor<128x4xf32, #blocked>
    %48 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<128x1x!tt.ptr<f32, 1>, #blocked>
    %49 = tt.addptr %48, %5 : tensor<128x1x!tt.ptr<f32, 1>, #blocked>, tensor<128x1xi32, #blocked>
    %52 = tt.load %49, %cst_5 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<128x1xf32, #blocked>
    %56:1 = "tt.reduce"(%28) <{axis = 1 : i32}> ({
    ^bb0(%arg2: f32, %arg3: f32):
      tt.reduce.return %arg3 : f32
    }) : (tensor<128x4xf32, #blocked>) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>)
    %60 = tt.expand_dims %56#0 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xf32, #blocked>
    %74 = arith.addf %60, %52 : tensor<128x1xf32, #blocked>
    tt.store %48, %74 {cache = 1 : i32, evict = 1 : i32} : tensor<128x1xf32, #blocked>
    tt.return
  }
}
```
Review comment: This part has changed multiple times between our current and the previous IFU. I wrapped it with `USE_ROCM` to highlight it in case this PR is merged before the ongoing IFU.