diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 220f2d1c3df2..21b2490d15bf 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -864,7 +864,6 @@ LogicalResult getConvertBackwardSlice(
   auto updateLayout = [&](Value value, Attribute encoding) {
     assert((isa<RankedTensorType>(value.getType())));
-    slice.insert(value);
     Attribute &existing = layout[value];
     if (existing && existing != encoding)
       return failure();
@@ -876,7 +875,8 @@ LogicalResult getConvertBackwardSlice(
     auto [currentValueUse, encoding] = queue.back();
     Value currentValue = currentValueUse->get();
     queue.pop_back();
-    if (!isa<RankedTensorType>(currentValue.getType()))
+    auto currentValueType = dyn_cast<RankedTensorType>(currentValue.getType());
+    if (!currentValueType)
       continue;
     // Skip propagating through for op/while op/ws op results for now.
     // TODO: enable this based on needs.
@@ -885,6 +885,11 @@ LogicalResult getConvertBackwardSlice(
       return failure();
    if (failed(updateLayout(currentValue, encoding)))
       return failure();
+    // If the value already has the desired encoding, we can stop here without
+    // adding it to the slice.
+    if (currentValueType.getEncoding() == encoding)
+      continue;
+    slice.insert(currentValue);
    // If there is already an existing conversion to the target layout, we don't
    // need to propagate to the operands.
@@ -915,6 +920,7 @@ LogicalResult getConvertBackwardSlice(
        continue;
      if (failed(updateLayout(result, encoding)))
        return failure();
+      slice.insert(result);
    }
    if (isFreeConvert(definingOp)) {
      enqueue(definingOp->getOpOperand(0), encoding);
@@ -945,13 +951,6 @@ LogicalResult getConvertBackwardSlice(
      }
      if (!srcEncoding)
        return failure();
-      // If the infered layout matches the original one we don't need to keep
-      // propagating.
-      if (auto operandType =
-              dyn_cast<RankedTensorType>(operand.get().getType())) {
-        if (srcEncoding == operandType.getEncoding())
-          continue;
-      }
      enqueue(operand, srcEncoding);
    }
    continue;
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index b89d933221d3..571bf5251520 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -4190,3 +4190,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
    tt.return %0#0 : i32
  }
}
+
+// -----
+
+#src = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#dst = #ttg.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 1, 1], warpsPerCTA = [1, 1, 1], order = [0, 1, 2]}>
+#lin = #ttg.linear<{register = [[0, 1, 0]], lane = [], warp = [], block = []}>
+module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 1 : i32} {
+  // CHECK-LABEL: @test_existing_layout_conflict
+  // CHECK: ttg.convert_layout
+  // CHECK: tt.return
+  tt.func @test_existing_layout_conflict() -> tensor<1x2x2xi32, #dst> {
+    %r = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #src}>>
+    %v = tt.expand_dims %r {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #src}>> -> tensor<1x2xi32, #src>
+    %j = tt.join %v, %v : tensor<1x2xi32, #src> -> tensor<1x2x2xi32, #lin>
+    %r3 = tt.reshape %v : tensor<1x2xi32, #src> -> tensor<1x2x1xi32, #lin>
+    %b = tt.broadcast %r3 : tensor<1x2x1xi32, #lin> -> tensor<1x2x2xi32, #lin>
+    %s = arith.addi %j, %b : tensor<1x2x2xi32, #lin>
+    %o = ttg.convert_layout %s : tensor<1x2x2xi32, #lin> -> tensor<1x2x2xi32, #dst>
+    tt.return %o : tensor<1x2x2xi32, #dst>
+  }
+}