17 changes: 8 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -864,7 +864,6 @@ LogicalResult getConvertBackwardSlice(

   auto updateLayout = [&](Value value, Attribute encoding) {
     assert((isa<RankedTensorType>(value.getType())));
-    slice.insert(value);
     Attribute &existing = layout[value];
     if (existing && existing != encoding)
       return failure();
@@ -876,7 +875,8 @@ LogicalResult getConvertBackwardSlice(
     auto [currentValueUse, encoding] = queue.back();
     Value currentValue = currentValueUse->get();
     queue.pop_back();
-    if (!isa<RankedTensorType>(currentValue.getType()))
+    auto currentValueType = dyn_cast<RankedTensorType>(currentValue.getType());
+    if (!currentValueType)
       continue;
     // Skip propagating through for op/while op/ws op results for now.
     // TODO: enable this based on needs.
@@ -885,6 +885,11 @@ LogicalResult getConvertBackwardSlice(
       return failure();
     if (failed(updateLayout(currentValue, encoding)))
       return failure();
+    // If the value already has the desired encoding, we can stop here without
+    // adding it to the slice.
+    if (currentValueType.getEncoding() == encoding)
+      continue;
+    slice.insert(currentValue);
 
     // If there is already an existing conversion to the target layout, we don't
     // need to propagate to the operands.
@@ -915,6 +920,7 @@ LogicalResult getConvertBackwardSlice(
         continue;
       if (failed(updateLayout(result, encoding)))
         return failure();
+      slice.insert(result);
     }
     if (isFreeConvert(definingOp)) {
       enqueue(definingOp->getOpOperand(0), encoding);
@@ -945,13 +951,6 @@ LogicalResult getConvertBackwardSlice(
       }
       if (!srcEncoding)
         return failure();
-      // If the infered layout matches the original one we don't need to keep
-      // propagating.
-      if (auto operandType =
-              dyn_cast<RankedTensorType>(operand.get().getType())) {
-        if (srcEncoding == operandType.getEncoding())
-          continue;
-      }
       enqueue(operand, srcEncoding);
     }
     continue;
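The net effect of the Utility.cpp hunks is that a value is inserted into the backward slice only after its requested layout has been recorded, and the walk stops early (without inserting the value or visiting its producers) when the value's existing encoding already matches the requested one. Below is a minimal, self-contained sketch of that worklist pattern, not the Triton implementation; the `Graph`, `encodingOf`, and `producersOf` names are hypothetical stand-ins for MLIR values, tensor encodings, and def-use edges.

// Hypothetical stand-alone sketch; not the Triton code. Models the worklist
// walk in getConvertBackwardSlice after this change: record the requested
// layout first, stop early when the value already has that encoding, and only
// then insert the value into the slice and visit its producers.
#include <deque>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Value = int;            // stand-in for mlir::Value
using Encoding = std::string; // stand-in for a layout attribute

struct Graph {
  std::map<Value, Encoding> encodingOf;            // encoding each value already carries
  std::map<Value, std::vector<Value>> producersOf; // backward (use -> def) edges
};

// Returns false on a layout conflict, mirroring the failure() paths in the diff.
bool getBackwardSlice(const Graph &g, Value root, const Encoding &target,
                      std::set<Value> &slice,
                      std::map<Value, Encoding> &layout) {
  std::deque<std::pair<Value, Encoding>> queue;
  queue.emplace_back(root, target);
  while (!queue.empty()) {
    auto [value, encoding] = queue.back();
    queue.pop_back();

    // Record the requested layout first; two different requests for the same
    // value are a conflict.
    auto [it, inserted] = layout.emplace(value, encoding);
    if (!inserted && it->second != encoding)
      return false;

    // Early exit: the value already carries the desired encoding, so nothing
    // needs rewriting. It is neither added to the slice nor propagated from.
    auto enc = g.encodingOf.find(value);
    if (enc != g.encodingOf.end() && enc->second == encoding)
      continue;

    slice.insert(value);
    auto prods = g.producersOf.find(value);
    if (prods == g.producersOf.end())
      continue;
    for (Value producer : prods->second)
      queue.emplace_back(producer, encoding);
  }
  return true;
}

int main() {
  // Value 2 feeds value 1 feeds value 0; value 2 already has the target
  // encoding, so the slice should contain only {0, 1}.
  Graph g;
  g.encodingOf = {{0, "blocked"}, {1, "blocked"}, {2, "linear"}};
  g.producersOf = {{0, {1}}, {1, {2}}};
  std::set<Value> slice;
  std::map<Value, Encoding> layout;
  bool ok = getBackwardSlice(g, /*root=*/0, /*target=*/"linear", slice, layout);
  std::cout << "ok=" << ok << " slice size=" << slice.size() << "\n"; // ok=1 slice size=2
  return 0;
}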
21 changes: 21 additions & 0 deletions test/TritonGPU/combine.mlir
@@ -4190,3 +4190,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return %0#0 : i32
   }
 }
+
+// -----
+
+#src = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#dst = #ttg.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 1, 1], warpsPerCTA = [1, 1, 1], order = [0, 1, 2]}>
+#lin = #ttg.linear<{register = [[0, 1, 0]], lane = [], warp = [], block = []}>
+module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 1 : i32} {
+  // CHECK-LABEL: @test_existing_layout_conflict
+  // CHECK: ttg.convert_layout
+  // CHECK: tt.return
+  tt.func @test_existing_layout_conflict() -> tensor<1x2x2xi32, #dst> {
+    %r = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #src}>>
+    %v = tt.expand_dims %r {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #src}>> -> tensor<1x2xi32, #src>
+    %j = tt.join %v, %v : tensor<1x2xi32, #src> -> tensor<1x2x2xi32, #lin>
+    %r3 = tt.reshape %v : tensor<1x2xi32, #src> -> tensor<1x2x1xi32, #lin>
+    %b = tt.broadcast %r3 : tensor<1x2x1xi32, #lin> -> tensor<1x2x2xi32, #lin>
+    %s = arith.addi %j, %b : tensor<1x2x2xi32, #lin>
+    %o = ttg.convert_layout %s : tensor<1x2x2xi32, #lin> -> tensor<1x2x2xi32, #dst>
+    tt.return %o : tensor<1x2x2xi32, #dst>
+  }
+}