Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,20 @@ struct MoveBroadcastAfterElementwisePattern

auto operands = op->getOperands();
bool seenBroadcast = false;
Type srcType;
ArrayRef<int64_t> srcShape;
for (auto operand : operands) {
auto definingOp = operand.getDefiningOp();
if (!definingOp) {
return mlir::failure();
}

auto getSrcShape = [](triton::BroadcastOp b) {
return b.getSrc().getType().cast<RankedTensorType>().getShape();
};
if (auto broadcastOp = llvm::dyn_cast<triton::BroadcastOp>(definingOp)) {
if (!seenBroadcast) {
seenBroadcast = true;
srcType = broadcastOp.getSrc().getType();
} else if (srcType != broadcastOp.getSrc().getType()) {
srcShape = getSrcShape(broadcastOp);
} else if (srcShape != getSrcShape(broadcastOp)) {
// If the broadcast have different types we cannot re-order.
return mlir::failure();
}
Expand Down
16 changes: 11 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,7 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {

auto isExtOrBroadcastOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
triton::BroadcastOp>(op);
triton::BroadcastOp, triton::ExpandDimsOp>(op);
};
// 1. Take a backward slice of all the tensor dependencies.
SetVector<Value> slice;
Expand All @@ -950,8 +950,11 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
if (isExtOrBroadcastOp(op)) {
SetVector<Value> tempSlice;
DenseMap<Value, Attribute> tempLayout;
std::optional<Attribute> srcEncoding = inferSrcEncoding(op, layout[v]);
if (!srcEncoding)
return;
LogicalResult result = getRematerializableSlice(
op->getOperand(0), layout[v], tempSlice, tempLayout);
op->getOperand(0), *srcEncoding, tempSlice, tempLayout);
// If we can rematerialize the rest of the ext slice we can ignore this
// ext as it won't need a convert.
if (result.succeeded()) {
Expand All @@ -969,13 +972,16 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {

if (extOrBroadcatOp == nullptr)
return;
std::optional<Attribute> srcEncoding =
inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]);
if (!srcEncoding)
return;
// Move the convert before the ext op and rewrite the slice.
OpBuilder builder(extOrBroadcatOp);
auto tensorType =
extOrBroadcatOp->getOperand(0).getType().cast<RankedTensorType>();
auto newType =
RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(),
layout[extOrBroadcatOp->getResult(0)]);
auto newType = RankedTensorType::get(
tensorType.getShape(), tensorType.getElementType(), *srcEncoding);
auto newConvertOp = builder.create<ConvertLayoutOp>(
convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0));
IRMapping mapping;
Expand Down
12 changes: 12 additions & 0 deletions test/Triton/reorder-broadcast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,15 @@ tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tenso

tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32>
}

// CHECK-LABEL: @test_broadcast_mix_type_op_pattern
tt.func @test_broadcast_mix_type_op_pattern(%arg0: tensor<128x1xf32>, %arg1: f32, %arg2: tensor<1x128xf32>, %arg3: tensor<128x1xi1>) -> (tensor<128x128xf32>) {
// CHECK: %[[sel:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<128x1xi1>, tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[sel]] : (tensor<128x1xf32>) -> tensor<128x128xf32>
%broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32>
%broadcast1 = tt.splat %arg1 : (f32) -> tensor<128x128xf32>
%cond = tt.broadcast %arg3 : (tensor<128x1xi1>) -> tensor<128x128xi1>
%sel = arith.select %cond, %broadcast0, %broadcast1 : tensor<128x128xi1>, tensor<128x128xf32>

tt.return %sel : tensor<128x128xf32>
}
4 changes: 3 additions & 1 deletion test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1189,10 +1189,12 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {

// CHECK-LABEL: reduce_cvt2
// Match the reduction
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: tt.reduce
// CHECK-SAME: axis = 1
// CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>>
// CHECK: triton_gpu.convert_layout
// CHECK: tt.expand_dims
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: tt.return
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
Expand Down