diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
index 545b7709d9c5..765e531e7b02 100644
--- a/include/triton/Dialect/Triton/IR/Dialect.h
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
@@ -87,6 +87,11 @@ class DialectInferLayoutInterface
   verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
                                    Attribute operandEncodingB) const = 0;
 
+  // Verify that the encodings are compatible to be used together in a cat
+  // operation.
+  virtual LogicalResult
+  verifyCatOpEncodingCompatibility(Operation *op) const = 0;
+
   virtual LogicalResult
   inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
                          Attribute &outEnc, bool fwdInference,
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 5c637e3b8368..16e5b3e966a7 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -491,6 +491,8 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
   let results = (outs TT_Tensor:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
+
+  let hasVerifier = 1;
 }
 
 def TT_JoinOp : TT_Op<"join", [
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 0a7866bf3efe..19794113c58a 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -260,7 +260,8 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
 SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                             bool kContig);
 
-bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
+// Return true if \p cat would be valid with result encoding \p targetEncoding.
+bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding);
 
 // Return true if a view between the two types cannot be implemented as a no-op.
 bool isExpensiveView(Type srcType, Type dstType);
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index 90a472038e7d..bda1d54aea7f 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -169,15 +169,7 @@ struct CatOpConversion : public ConvertOpToLLVMPattern<CatOp> {
     for (Value v : rhsVals)
       retVals.push_back(v);
 
-    if (retVals.size() != strippedDstLayout.getInDimSize(kReg)) {
-      return op->emitError()
-             << "tt.cat lowering expected "
-             << strippedDstLayout.getInDimSize(kReg)
-             << " (non-broadcast) register values for the result, but got "
-             << retVals.size()
-             << ". (hint: this usually means the operands/result encodings are "
-                "incompatible for the current CatOp lowering)";
-    }
+    assert(retVals.size() == strippedDstLayout.getInDimSize(kReg));
 
     // Re-introduce broadcasting if the destination expects it.
     if (!removeBroadcastDst.isIdentity())
diff --git a/lib/Dialect/Gluon/IR/Dialect.cpp b/lib/Dialect/Gluon/IR/Dialect.cpp
index 0a18ec8522f7..7054527167bc 100644
--- a/lib/Dialect/Gluon/IR/Dialect.cpp
+++ b/lib/Dialect/Gluon/IR/Dialect.cpp
@@ -65,6 +65,10 @@ struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface {
     return success();
   }
 
+  LogicalResult verifyCatOpEncodingCompatibility(Operation *op) const override {
+    return success();
+  }
+
   LogicalResult
   verifyLayoutsAreEqual(ArrayRef<int64_t> shape, Attribute expected,
                         Attribute got,
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index b7f61e6d4002..313fec9b7b2a 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -862,6 +862,31 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
   return foldViewLikeOp(*this, adaptor.getSrc());
 }
 
+//-- CatOp --
+LogicalResult CatOp::verify() {
+  RankedTensorType lhsTy = getLhs().getType();
+  RankedTensorType resultTy = getType();
+
+  int64_t operandElements = lhsTy.getNumElements() * 2;
+  if (resultTy.getNumElements() != operandElements) {
+    return emitOpError("result element count must equal the sum of the "
+                       "operand element counts, expected ")
+           << operandElements << " but got " << resultTy.getNumElements();
+  }
+
+  Attribute operandEnc = lhsTy.getEncoding();
+  Attribute resultEnc = resultTy.getEncoding();
+  if (!!operandEnc != !!resultEnc) {
+    return emitOpError("requires that either (a) operands and result all have "
+                       "encodings, or (b) none do.");
+  }
+  if (!resultEnc)
+    return success();
+
+  auto interface = cast<DialectInferLayoutInterface>(&resultEnc.getDialect());
+  return interface->verifyCatOpEncodingCompatibility(getOperation());
+}
+
 //-- ReshapeOp --
 void ReshapeOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<int64_t> shape, Value src, bool allowReorder) {
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 102eaa39697f..0c61bab02e52 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -409,16 +409,27 @@ SmallVector<unsigned> orderPerDimImpl(const LinearLayout &ll,
   return order.takeVector();
 }
 
-bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
-  // If the new elements per thread is less than the old one, we will need to
-  // do convert encoding that goes through shared memory anyway. So we
-  // consider it as expensive.
-  RankedTensorType tensorTy = cat.getType();
-  auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy);
-  auto shape = tensorTy.getShape();
-  auto newTotalElemsPerThread =
-      gpu::getTotalElemsPerThread(targetEncoding, shape);
-  return newTotalElemsPerThread < totalElemsPerThread;
+static int64_t getNumNonBroadcastRegisters(ArrayRef<int64_t> shape,
+                                           Attribute encoding) {
+  auto kReg = StringAttr::get(encoding.getContext(), "register");
+  auto strippedLayout =
+      toLinearLayout(shape, encoding).removeZeroBasesAlongDim(kReg);
+  return strippedLayout.getInDimSize(kReg);
+}
+
+static int64_t getNumNonBroadcastRegisters(RankedTensorType tensorType) {
+  return getNumNonBroadcastRegisters(tensorType.getShape(),
+                                     tensorType.getEncoding());
+}
+
+bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding) {
+  // Cat lowering concatenates the operands' unique register values. So the
+  // number of unique register values in the result must be equal to those in
+  // the operands.
+  int64_t operandRegs = getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2;
+  int64_t resultRegs =
+      getNumNonBroadcastRegisters(cat.getType().getShape(), targetEncoding);
+  return resultRegs == operandRegs;
 }
 
 static LogicalResult
@@ -3017,6 +3028,20 @@ struct TritonGPUInferLayoutInterface
     return success();
   }
 
+  LogicalResult verifyCatOpEncodingCompatibility(Operation *op) const override {
+    auto cat = cast<CatOp>(op);
+    int64_t operandRegs =
+        getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2;
+    int64_t resultRegs = getNumNonBroadcastRegisters(cat.getType());
+    if (resultRegs != operandRegs) {
+      return op->emitError("tt.cat result encoding requires ")
+             << resultRegs
+             << " non-broadcast register values, but operands provide "
+             << operandRegs;
+    }
+    return success();
+  }
+
   // Given a src shape + encoding and a dst shape, our goal is to compute a dst
   // encoding that makes the reshape a "nop".  That is, if GPU thread [x,y,z]
   // contains elements [a,b,c,d] before the reshape, it contains those same
diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp
index 0cd5bd71f31a..2e5125c3c377 100644
--- a/lib/Dialect/TritonGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -355,7 +355,7 @@ struct CanonicalizeConvertFromConvert
 
     // cvt(cat) -> cat
     if (auto cat = dyn_cast<CatOp>(arg)) {
-      if (isExpensiveCat(cat, op.getType().getEncoding()))
+      if (!isLegalCatEncoding(cat, op.getType().getEncoding()))
         return failure();
       rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),
                                          cat.getOperands());
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index f8fe3361ce0d..aa414cd518e7 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -604,26 +604,10 @@ bool isExpensiveLoadOrStore(Operation *op) {
   return true;
 }
 
-bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
-  if (!op)
-    return true;
-  if (isa<triton::LoadOp, triton::StoreOp>(op))
-    return isExpensiveLoadOrStore(op);
-  if (isa<triton::CatOp>(op))
-    return triton::gpu::isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
-  if (isa<triton::gpu::AsyncCopyGlobalToLocalOp, triton::AtomicRMWOp,
-          triton::AtomicCASOp>(op))
-    return true;
-  if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
-          op))
-    return true;
-  return false;
-}
-
 bool canUseResultEncoding(Operation *op, Attribute targetEncoding) {
   if (isa<triton::CatOp>(op))
-    return !triton::gpu::isExpensiveCat(cast<triton::CatOp>(op),
-                                        targetEncoding);
+    return triton::gpu::isLegalCatEncoding(cast<triton::CatOp>(op),
+                                           targetEncoding);
   if (auto convert = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
     if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(targetEncoding)) {
       auto srcEncoding = convert.getSrc().getType().getEncoding();
diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir
index a2de394abc2d..a2da914eba7d 100644
--- a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir
+++ b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir
@@ -295,16 +295,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // COMMON-LABEL: join_cat_transitive_nonneg
   tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
     %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1>
     %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked1>
     %2 = tt.join %0, %1 : tensor<8xi32, #blocked1> -> tensor<8x2xi32, #blocked>
-    %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1>
-    %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1>
-    %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked>
-    %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
+    %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked2>
+    %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked2>
+    %5 = tt.join %3, %4 : tensor<4xi32, #blocked2> -> tensor<4x2xi32, #blocked3>
+    %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked3> -> tensor<8x2xi32, #blocked>
     %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked>
     %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked>
     %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked>
diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
index 94413edc06d1..922717d546b7 100644
--- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
+++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
@@ -335,16 +335,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // COMMON-LABEL: join_cat_transitive_nonneg
   tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
     %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1>
     %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked1>
     %2 = tt.join %0, %1 : tensor<8xi32, #blocked1> -> tensor<8x2xi32, #blocked>
-    %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1>
-    %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1>
-    %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked>
-    %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
+    %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked2>
+    %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked2>
+    %5 = tt.join %3, %4 : tensor<4xi32, #blocked2> -> tensor<4x2xi32, #blocked3>
+    %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked3> -> tensor<8x2xi32, #blocked>
     %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked>
     %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked>
     %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked>
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 602b0886a111..3a6286567afd 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -4176,3 +4176,47 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 1 : i32}
     tt.return %o : tensor<1x2x2xi32, #dst>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#linear = #ttg.linear<{register = [[1], [16]], lane = [[0], [0], [2], [4], [8]], warp = [[0], [0]], block = []}>
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @cat_incompatible_target_keeps_convert
+  tt.func public @cat_incompatible_target_keeps_convert(%out: !tt.ptr<i32>) {
+    %lhs = arith.constant dense<0> : tensor<16xi32, #blocked>
+    %rhs = arith.constant dense<1> : tensor<16xi32, #blocked>
+    // CHECK: %[[CAT:[^ ]+]] = tt.cat
+    %cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2>
+    // CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]]
+    %cvt = ttg.convert_layout %cat {allocation.offset = 0 : i32} : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear>
+    %ptr = tt.splat %out : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>, #linear>
+    // CHECK: tt.store {{.*}}, %[[CVT]]
+    tt.store %ptr, %cvt : tensor<32x!tt.ptr<i32>, #linear>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#linear_bcast = #ttg.linear<{register = [[0]], lane = [[1], [2], [4], [8], [16]], warp = [[0], [0]], block = []}>
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @cat_target_adds_broadcasting_keeps_convert
+  tt.func public @cat_target_adds_broadcasting_keeps_convert(%out: !tt.ptr<i32>) {
+    %lhs = arith.constant dense<0> : tensor<16xi32, #blocked>
+    %rhs = arith.constant dense<1> : tensor<16xi32, #blocked>
+    // CHECK: %[[CAT:[^ ]+]] = tt.cat
+    %cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2>
+    // CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]]
+    %cvt = ttg.convert_layout %cat : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear_bcast>
+    %ptr = tt.splat %out : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>, #linear_bcast>
+    // CHECK: tt.store {{.*}}, %[[CVT]]
+    tt.store %ptr, %cvt : tensor<32x!tt.ptr<i32>, #linear_bcast>
+    tt.return
+  }
+}
diff --git a/test/TritonGPU/verify-blocked-layout.mlir b/test/TritonGPU/verify-blocked-layout.mlir
index 9fd372b55056..91b90d2276fa 100644
--- a/test/TritonGPU/verify-blocked-layout.mlir
+++ b/test/TritonGPU/verify-blocked-layout.mlir
@@ -114,3 +114,17 @@ module attributes {
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#linear = #ttg.linear<{register = [[1], [16]], lane = [[0], [0], [2], [4], [8]], warp = [[0], [0]], block = []}>
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @invalid_cat_layout() {
+    %lhs = arith.constant dense<0> : tensor<16xi32, #blocked>
+    %rhs = arith.constant dense<1> : tensor<16xi32, #blocked>
+    // expected-error @+1 {{tt.cat result encoding requires 4 non-broadcast register values, but operands provide 2}}
+    %cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #linear>
+    tt.return
+  }
+}