Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,21 @@ struct MoveBroadcastAfterElementwisePattern

auto operands = op->getOperands();
bool seenBroadcast = false;
Type srcType;
for (auto operand : operands) {
auto definingOp = operand.getDefiningOp();
if (!definingOp) {
return mlir::failure();
}

if (auto broadcastOp = llvm::dyn_cast<triton::BroadcastOp>(definingOp)) {
if (seenBroadcast) {
// Only support one broadcasted argument for now
if (!seenBroadcast) {
seenBroadcast = true;
srcType = broadcastOp.getSrc().getType();
} else if (srcType != broadcastOp.getSrc().getType()) {
// If the broadcast have different types we cannot re-order.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nice improvement. I think you can generalize it slightly further by having a common srcShape and allowing the scalar types to differ. This would allow the pattern to work for AddPtrOp for example.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for the delay, thanks for the advice. I'm sending another PR to handle this.

return mlir::failure();
}
seenBroadcast = true;
} else if (!isSplat(definingOp)) {
// Not splat or broadcast
return mlir::failure();
Expand All @@ -149,17 +152,17 @@ struct MoveBroadcastAfterElementwisePattern
}
}

auto src = broadcastOp.getSrc();
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
auto srcTy = broadcastOp.getSrc().getType().dyn_cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto srcEncoding = srcTy.getEncoding();

// Reshape operands to match srcShape
llvm::SmallVector<Value, 4> newOperands;
for (auto operand : operands) {
auto definingOp = operand.getDefiningOp();
if (llvm::isa<triton::BroadcastOp>(definingOp)) {
newOperands.push_back(src);
if (auto broadcastSrcOp =
llvm::dyn_cast<triton::BroadcastOp>(definingOp)) {
newOperands.push_back(broadcastSrcOp.getSrc());
continue;
}
auto elemTy =
Expand Down
35 changes: 19 additions & 16 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ static void backwardRematerialization(ConvertLayoutOp convertOp) {

// For convert left we try to hoist them above type extension to reduce the cost
// of the convert.
static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) {
static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
// we don't want to rematerialize any conversion to/from shared
if (triton::gpu::isSharedEncoding(convertOp.getResult()) ||
triton::gpu::isSharedEncoding(convertOp.getOperand()))
Expand All @@ -926,25 +926,27 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) {
if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return;

auto isExtOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op);
auto isExtOrBroadcastOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
triton::BroadcastOp>(op);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this apply to ExpandDim and Splat as well?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Splat takes a scalar source so it is already handled. I'll extend to ExpandDim

};
// 1. Take a backward slice of all the tensor dependencies.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result = getRematerializableSlice(
convertOp.getOperand(), targetType.getEncoding(), slice, layout, isExtOp);
LogicalResult result =
getRematerializableSlice(convertOp.getOperand(), targetType.getEncoding(),
slice, layout, isExtOrBroadcastOp);
if (result.failed())
return;

Operation *extOp = nullptr;
Operation *extOrBroadcatOp = nullptr;
unsigned sliceSize = slice.size();
for (unsigned i = 0; i < sliceSize; i++) {
Value v = slice[i];
Operation *op = v.getDefiningOp();
if (!op)
continue;
if (isExtOp(op)) {
if (isExtOrBroadcastOp(op)) {
SetVector<Value> tempSlice;
DenseMap<Value, Attribute> tempLayout;
LogicalResult result = getRematerializableSlice(
Expand All @@ -958,24 +960,25 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) {
}
// Only apply it if there is a single ext op otherwise we would have to
// duplicate the convert.
if (extOp != nullptr)
if (extOrBroadcatOp != nullptr)
return;
extOp = op;
extOrBroadcatOp = op;
}
}

if (extOp == nullptr)
if (extOrBroadcatOp == nullptr)
return;
// Move the convert before the ext op and rewrite the slice.
OpBuilder builder(extOp);
auto tensorType = extOp->getOperand(0).getType().cast<RankedTensorType>();
OpBuilder builder(extOrBroadcatOp);
auto tensorType =
extOrBroadcatOp->getOperand(0).getType().cast<RankedTensorType>();
auto newType =
RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(),
layout[extOp->getResult(0)]);
layout[extOrBroadcatOp->getResult(0)]);
auto newConvertOp = builder.create<ConvertLayoutOp>(
convertOp.getLoc(), newType, extOp->getOperand(0));
convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0));
IRMapping mapping;
mapping.map(extOp->getOperand(0), newConvertOp.getResult());
mapping.map(extOrBroadcatOp->getOperand(0), newConvertOp.getResult());
// 3. Rewrite the slice.
rewriteSlice(slice, layout, convertOp, mapping);
}
Expand All @@ -994,7 +997,7 @@ static void hoistConvert(ModuleOp module) {
module.walk(
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
hoistConvertOnTopOfExt(convertOp);
hoistConvertOnTopOfExtOrBroadcast(convertOp);
}
}

Expand Down
15 changes: 15 additions & 0 deletions test/Triton/reorder-broadcast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,18 @@ tt.func @test_broadcast_elementwise_pattern(%arg0: tensor<128x1xf32>) -> (tensor

tt.return %abs, %add : tensor<128x128xf32>, tensor<128x32xf32>
}

// CHECK-LABEL: @test_broadcast_binary_op_pattern
tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tensor<128x1xf32>, %arg2: tensor<1x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
// CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[mul]] : (tensor<128x1xf32>) -> tensor<128x128xf32>
%broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32>
%broadcast1 = tt.broadcast %arg1 : (tensor<128x1xf32>) -> tensor<128x128xf32>
%mul = arith.mulf %broadcast0, %broadcast1 : tensor<128x128xf32>

// CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x128xf32>
%broadcast2 = tt.broadcast %arg2 : (tensor<1x128xf32>) -> tensor<128x128xf32>
%mul1 = arith.mulf %broadcast0, %broadcast2 : tensor<128x128xf32>

tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32>
}
17 changes: 17 additions & 0 deletions test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>

#layout2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#layout3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>


module attributes {"triton_gpu.num-warps" = 4 : i32} {

// CHECK: [[$target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
Expand Down Expand Up @@ -105,6 +109,19 @@ tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tens
tt.return %4 : tensor<1024xf32, #layout1>
}

// Hoist the convert on top of broadcast to make it cheaper.
// CHECK-LABEL: hoist_above_broadcast
tt.func @hoist_above_broadcast(%arg0: tensor<1024x1xf32, #layout2>, %arg1: f32) -> tensor<1024x128xf32, #layout3> {
// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout
// CHECK: tt.broadcast %[[CVT]]
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: tt.return
%0 = tt.broadcast %arg0 : (tensor<1024x1xf32, #layout2>) -> tensor<1024x128xf32, #layout2>
%1 = tt.splat %arg1 : (f32) -> tensor<1024x128xf32, #layout2>
%2 = arith.addf %0, %1 : tensor<1024x128xf32, #layout2>
%3 = triton_gpu.convert_layout %2 : (tensor<1024x128xf32, #layout2>) -> tensor<1024x128xf32, #layout3>
tt.return %3 : tensor<1024x128xf32, #layout3>
}


// CHECK-LABEL: if
Expand Down