diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 0a7866bf3efe..19794113c58a 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -260,7 +260,8 @@ SmallVector getMatrixOrder(unsigned rank, bool rowMajor); SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, bool kContig); -bool isExpensiveCat(CatOp cat, Attribute targetEncoding); +// Return true if \p cat would be valid with result encoding \p targetEncoding. +bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding); // Return true if a view between the two types cannot be implemented as a no-op. bool isExpensiveView(Type srcType, Type dstType); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 102eaa39697f..543284928b7a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -409,16 +409,27 @@ SmallVector orderPerDimImpl(const LinearLayout &ll, return order.takeVector(); } -bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { - // If the new elements per thread is less than the old one, we will need to - // do convert encoding that goes through shared memory anyway. So we - // consider it as expensive. - RankedTensorType tensorTy = cat.getType(); - auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); - auto shape = tensorTy.getShape(); - auto newTotalElemsPerThread = - gpu::getTotalElemsPerThread(targetEncoding, shape); - return newTotalElemsPerThread < totalElemsPerThread; +static int64_t getNumNonBroadcastRegisters(ArrayRef shape, + Attribute encoding) { + auto kReg = StringAttr::get(encoding.getContext(), "register"); + auto strippedLayout = + toLinearLayout(shape, encoding).removeZeroBasesAlongDim(kReg); + return strippedLayout.getInDimSize(kReg); +} + +static int64_t getNumNonBroadcastRegisters(RankedTensorType tensorType) { + return getNumNonBroadcastRegisters(tensorType.getShape(), + tensorType.getEncoding()); +} + +bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding) { + // Cat lowering concatenates the operands' unique register values. So the + // number of unique register values in the result must be equal to those in + // the operands. + int64_t operandRegs = getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2; + int64_t resultRegs = + getNumNonBroadcastRegisters(cat.getType().getShape(), targetEncoding); + return resultRegs == operandRegs; } static LogicalResult diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 0cd5bd71f31a..2e5125c3c377 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -355,7 +355,7 @@ struct CanonicalizeConvertFromConvert // cvt(cat) -> cat if (auto cat = dyn_cast(arg)) { - if (isExpensiveCat(cat, op.getType().getEncoding())) + if (!isLegalCatEncoding(cat, op.getType().getEncoding())) return failure(); rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index f8fe3361ce0d..aa414cd518e7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -604,26 +604,10 @@ bool isExpensiveLoadOrStore(Operation *op) { return true; } -bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { - if (!op) - return true; - if (isa(op)) - return isExpensiveLoadOrStore(op); - if (isa(op)) - return triton::gpu::isExpensiveCat(cast(op), targetEncoding); - if (isa(op)) - return true; - if (isa( - op)) - return true; - return false; -} - bool canUseResultEncoding(Operation *op, Attribute targetEncoding) { if (isa(op)) - return !triton::gpu::isExpensiveCat(cast(op), - targetEncoding); + return triton::gpu::isLegalCatEncoding(cast(op), + targetEncoding); if (auto convert = dyn_cast(op)) { if (mlir::isa(targetEncoding)) { auto srcEncoding = convert.getSrc().getType().getEncoding(); diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 602b0886a111..3a6286567afd 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -4176,3 +4176,47 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 1 : i32} tt.return %o : tensor<1x2x2xi32, #dst> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#linear = #ttg.linear<{register = [[1], [16]], lane = [[0], [0], [2], [4], [8]], warp = [[0], [0]], block = []}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @cat_incompatible_target_keeps_convert + tt.func public @cat_incompatible_target_keeps_convert(%out: !tt.ptr) { + %lhs = arith.constant dense<0> : tensor<16xi32, #blocked> + %rhs = arith.constant dense<1> : tensor<16xi32, #blocked> + // CHECK: %[[CAT:[^ ]+]] = tt.cat + %cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2> + // CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]] + %cvt = ttg.convert_layout %cat {allocation.offset = 0 : i32} : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear> + %ptr = tt.splat %out : !tt.ptr -> tensor<32x!tt.ptr, #linear> + // CHECK: tt.store {{.*}}, %[[CVT]] + tt.store %ptr, %cvt : tensor<32x!tt.ptr, #linear> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#linear_bcast = #ttg.linear<{register = [[0]], lane = [[1], [2], [4], [8], [16]], warp = [[0], [0]], block = []}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @cat_target_adds_broadcasting_keeps_convert + tt.func public @cat_target_adds_broadcasting_keeps_convert(%out: !tt.ptr) { + %lhs = arith.constant dense<0> : tensor<16xi32, #blocked> + %rhs = arith.constant dense<1> : tensor<16xi32, #blocked> + // CHECK: %[[CAT:[^ ]+]] = tt.cat + %cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2> + // CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]] + %cvt = ttg.convert_layout %cat : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear_bcast> + %ptr = tt.splat %out : !tt.ptr -> tensor<32x!tt.ptr, #linear_bcast> + // CHECK: tt.store {{.*}}, %[[CVT]] + tt.store %ptr, %cvt : tensor<32x!tt.ptr, #linear_bcast> + tt.return + } +}