Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kContig);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
// Return true if \p cat would be valid with result encoding \p targetEncoding.
bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding);

// Return true if a view between the two types cannot be implemented as a no-op.
bool isExpensiveView(Type srcType, Type dstType);
Expand Down
31 changes: 21 additions & 10 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,16 +409,27 @@ SmallVector<unsigned> orderPerDimImpl(const LinearLayout &ll,
return order.takeVector();
}

bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
// If the new elements per thread is less than the old one, we will need to
// do convert encoding that goes through shared memory anyway. So we
// consider it as expensive.
RankedTensorType tensorTy = cat.getType();
auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy);
auto shape = tensorTy.getShape();
auto newTotalElemsPerThread =
gpu::getTotalElemsPerThread(targetEncoding, shape);
return newTotalElemsPerThread < totalElemsPerThread;
static int64_t getNumNonBroadcastRegisters(ArrayRef<int64_t> shape,
Attribute encoding) {
auto kReg = StringAttr::get(encoding.getContext(), "register");
auto strippedLayout =
toLinearLayout(shape, encoding).removeZeroBasesAlongDim(kReg);
return strippedLayout.getInDimSize(kReg);
Comment thread
peterbell10 marked this conversation as resolved.
}

static int64_t getNumNonBroadcastRegisters(RankedTensorType tensorType) {
return getNumNonBroadcastRegisters(tensorType.getShape(),
tensorType.getEncoding());
}

bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding) {
// Cat lowering concatenates the operands' unique register values. So the
// number of unique register values in the result must be equal to those in
// the operands.
int64_t operandRegs = getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2;
int64_t resultRegs =
getNumNonBroadcastRegisters(cat.getType().getShape(), targetEncoding);
return resultRegs == operandRegs;
}

static LogicalResult
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ struct CanonicalizeConvertFromConvert

// cvt(cat) -> cat
if (auto cat = dyn_cast<CatOp>(arg)) {
if (isExpensiveCat(cat, op.getType().getEncoding()))
if (!isLegalCatEncoding(cat, op.getType().getEncoding()))
return failure();

rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),
Expand Down
20 changes: 2 additions & 18 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,26 +604,10 @@ bool isExpensiveLoadOrStore(Operation *op) {
return true;
}

bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return isExpensiveLoadOrStore(op);
if (isa<triton::CatOp>(op))
return triton::gpu::isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
if (isa<triton::gpu::AsyncCopyGlobalToLocalOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
op))
return true;
return false;
}

bool canUseResultEncoding(Operation *op, Attribute targetEncoding) {
if (isa<triton::CatOp>(op))
return !triton::gpu::isExpensiveCat(cast<triton::CatOp>(op),
targetEncoding);
return triton::gpu::isLegalCatEncoding(cast<triton::CatOp>(op),
targetEncoding);
if (auto convert = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
if (mlir::isa<triton::gpu::NvidiaMmaEncodingAttr>(targetEncoding)) {
auto srcEncoding = convert.getSrc().getType().getEncoding();
Expand Down
44 changes: 44 additions & 0 deletions test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4176,3 +4176,47 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 1 : i32}
tt.return %o : tensor<1x2x2xi32, #dst>
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#linear = #ttg.linear<{register = [[1], [16]], lane = [[0], [0], [2], [4], [8]], warp = [[0], [0]], block = []}>

module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @cat_incompatible_target_keeps_convert
tt.func public @cat_incompatible_target_keeps_convert(%out: !tt.ptr<i32>) {
%lhs = arith.constant dense<0> : tensor<16xi32, #blocked>
%rhs = arith.constant dense<1> : tensor<16xi32, #blocked>
// CHECK: %[[CAT:[^ ]+]] = tt.cat
%cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2>
// CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]]
%cvt = ttg.convert_layout %cat {allocation.offset = 0 : i32} : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear>
%ptr = tt.splat %out : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>, #linear>
// CHECK: tt.store {{.*}}, %[[CVT]]
tt.store %ptr, %cvt : tensor<32x!tt.ptr<i32>, #linear>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#linear_bcast = #ttg.linear<{register = [[0]], lane = [[1], [2], [4], [8], [16]], warp = [[0], [0]], block = []}>

module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @cat_target_adds_broadcasting_keeps_convert
tt.func public @cat_target_adds_broadcasting_keeps_convert(%out: !tt.ptr<i32>) {
%lhs = arith.constant dense<0> : tensor<16xi32, #blocked>
%rhs = arith.constant dense<1> : tensor<16xi32, #blocked>
// CHECK: %[[CAT:[^ ]+]] = tt.cat
%cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2>
// CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]]
%cvt = ttg.convert_layout %cat : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear_bcast>
%ptr = tt.splat %out : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>, #linear_bcast>
// CHECK: tt.store {{.*}}, %[[CVT]]
tt.store %ptr, %cvt : tensor<32x!tt.ptr<i32>, #linear_bcast>
tt.return
}
}
Loading