diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 545b7709d9c5..f9e0fa6c6a09 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -59,10 +59,13 @@ class DialectInferLayoutInterface // makes the reshape a "nop", i.e. the same GPU threads contain the same // elements as before the reshape using legacy layouts. This is not always // possible (in which case we fallback to using LinearLayouts) + // If allowReorder is set, an existing value in dstEnc is preferred when it + // still yields a non-expensive view. // In the future we'll always use LinearLayouts virtual LogicalResult inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, ArrayRef dstShape, Attribute &dstEnc, + bool allowReorder, std::optional loc) const = 0; // Check if two layouts are structurally the same, even if their names are @@ -87,6 +90,11 @@ class DialectInferLayoutInterface verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, Attribute operandEncodingB) const = 0; + // Verify that the encodings are compatible to be used together in a cat + // operation. + virtual LogicalResult + verifyCatOpEncodingCompatibility(Operation *op) const = 0; + virtual LogicalResult inferFp4ToFpOpEncoding(ArrayRef shape, int axis, Attribute inEnc, Attribute &outEnc, bool fwdInference, diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 5c637e3b8368..16e5b3e966a7 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -491,6 +491,8 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, let results = (outs TT_Tensor:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; + + let hasVerifier = 1; } def TT_JoinOp : TT_Op<"join", [ diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 0a7866bf3efe..fce1f71cb2fb 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -260,10 +260,18 @@ SmallVector getMatrixOrder(unsigned rank, bool rowMajor); SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, bool kContig); -bool isExpensiveCat(CatOp cat, Attribute targetEncoding); +// Return true if \p cat would be valid with result encoding \p targetEncoding. +bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding); // Return true if a view between the two types cannot be implemented as a no-op. -bool isExpensiveView(Type srcType, Type dstType); +bool isExpensiveView(ArrayRef srcShape, Attribute srcEncoding, + ArrayRef dstShape, Attribute dstEncoding); +inline bool isExpensiveView(Type srcType, Type dstType) { + auto tensorSrcType = cast(srcType); + auto tensorDstType = cast(dstType); + return isExpensiveView(tensorSrcType.getShape(), tensorSrcType.getEncoding(), + tensorDstType.getShape(), tensorDstType.getEncoding()); +} // Return a blocked encoding where the shape is distributed contiguously amongst // the threads, warps, CTAs with 1 element per threads. diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 90a472038e7d..bda1d54aea7f 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -169,15 +169,7 @@ struct CatOpConversion : public ConvertOpToLLVMPattern { for (Value v : rhsVals) retVals.push_back(v); - if (retVals.size() != strippedDstLayout.getInDimSize(kReg)) { - return op->emitError() - << "tt.cat lowering expected " - << strippedDstLayout.getInDimSize(kReg) - << " (non-broadcast) register values for the result, but got " - << retVals.size() - << ". (hint: this usually means the operands/result encodings are " - "incompatible for the current CatOp lowering)"; - } + assert(retVals.size() == strippedDstLayout.getInDimSize(kReg)); // Re-introduce broadcasting if the destination expects it. if (!removeBroadcastDst.isIdentity()) diff --git a/lib/Dialect/Gluon/IR/Dialect.cpp b/lib/Dialect/Gluon/IR/Dialect.cpp index 0a18ec8522f7..74c3b3fa84e0 100644 --- a/lib/Dialect/Gluon/IR/Dialect.cpp +++ b/lib/Dialect/Gluon/IR/Dialect.cpp @@ -65,6 +65,10 @@ struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface { return success(); } + LogicalResult verifyCatOpEncodingCompatibility(Operation *op) const override { + return success(); + } + LogicalResult verifyLayoutsAreEqual(ArrayRef shape, Attribute expected, Attribute got, @@ -74,7 +78,7 @@ struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface { LogicalResult inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, + ArrayRef dstShape, Attribute &dstEnc, bool, std::optional loc) const override { return inferAutoEncoding(srcEnc, dstEnc); } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index b7f61e6d4002..a3e6a7c2c3df 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -862,6 +862,31 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { return foldViewLikeOp(*this, adaptor.getSrc()); } +//-- CatOp -- +LogicalResult CatOp::verify() { + RankedTensorType lhsTy = getLhs().getType(); + RankedTensorType resultTy = getType(); + + int64_t operandElements = lhsTy.getNumElements() * 2; + if (resultTy.getNumElements() != operandElements) { + return emitOpError("result element count must equal the sum of the " + "operand element counts, expected ") + << operandElements << " but got " << resultTy.getNumElements(); + } + + Attribute operandEnc = lhsTy.getEncoding(); + Attribute resultEnc = resultTy.getEncoding(); + if (!!operandEnc != !!resultEnc) { + return emitOpError("requires that either (a) operands and result all have " + "encodings, or (b) none do."); + } + if (!resultEnc) + return success(); + + auto interface = cast(&resultEnc.getDialect()); + return interface->verifyCatOpEncodingCompatibility(getOperation()); +} + //-- ReshapeOp -- void ReshapeOp::build(OpBuilder &builder, OperationState &state, @@ -870,9 +895,10 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &state, auto srcEnc = srcTy.getEncoding(); Attribute dstEnc; if (srcEnc) { - auto result = cast(&srcEnc.getDialect()) - ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, - dstEnc, state.location); + auto result = + cast(&srcEnc.getDialect()) + ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, dstEnc, + allowReorder, state.location); assert(succeeded(result)); } auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc); @@ -942,7 +968,8 @@ LogicalResult ReshapeOp::verify() { auto layoutInterface = cast(&srcEnc.getDialect()); auto result = layoutInterface->inferReshapeOpEncoding( - srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc()); + srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, + /*allowReorder=*/false, getLoc()); if (failed(result)) return failure(); return layoutInterface->verifyLayoutsAreEqual( diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 102eaa39697f..974e09792e48 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -114,11 +114,10 @@ SmallVector getContigPerThread(RankedTensorType type) { return toLinearEncoding(type).getContigPerThread(); } -bool isExpensiveView(Type srcType, Type dstType) { - auto tensorSrcType = cast(srcType); - auto tensorDstType = cast(dstType); - auto llSrc = toLinearLayout(tensorSrcType); - auto llDst = toLinearLayout(tensorDstType); +bool isExpensiveView(ArrayRef srcShape, Attribute srcEncoding, + ArrayRef dstShape, Attribute dstEncoding) { + auto llSrc = toLinearLayout(srcShape, srcEncoding); + auto llDst = toLinearLayout(dstShape, dstEncoding); // In case there are replicated value we need to make sure the new and old // layout have matching masks. for (auto [srcMask, dstMask] : @@ -127,7 +126,8 @@ bool isExpensiveView(Type srcType, Type dstType) { if (srcMask.second != dstMask.second) return true; } - return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); + return getTotalElemsPerThread(srcEncoding, srcShape) != + getTotalElemsPerThread(dstEncoding, dstShape); } /* Utility function used by get.*Order methods of SliceEncodingAttr. @@ -409,16 +409,27 @@ SmallVector orderPerDimImpl(const LinearLayout &ll, return order.takeVector(); } -bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { - // If the new elements per thread is less than the old one, we will need to - // do convert encoding that goes through shared memory anyway. So we - // consider it as expensive. - RankedTensorType tensorTy = cat.getType(); - auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); - auto shape = tensorTy.getShape(); - auto newTotalElemsPerThread = - gpu::getTotalElemsPerThread(targetEncoding, shape); - return newTotalElemsPerThread < totalElemsPerThread; +static int64_t getNumNonBroadcastRegisters(ArrayRef shape, + Attribute encoding) { + auto kReg = StringAttr::get(encoding.getContext(), "register"); + auto strippedLayout = + toLinearLayout(shape, encoding).removeZeroBasesAlongDim(kReg); + return strippedLayout.getInDimSize(kReg); +} + +static int64_t getNumNonBroadcastRegisters(RankedTensorType tensorType) { + return getNumNonBroadcastRegisters(tensorType.getShape(), + tensorType.getEncoding()); +} + +bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding) { + // Cat lowering concatenates the operands' unique register values. So the + // number of unique register values in the result must be equal to those in + // the operands. + int64_t operandRegs = getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2; + int64_t resultRegs = + getNumNonBroadcastRegisters(cat.getType().getShape(), targetEncoding); + return resultRegs == operandRegs; } static LogicalResult @@ -3017,6 +3028,20 @@ struct TritonGPUInferLayoutInterface return success(); } + LogicalResult verifyCatOpEncodingCompatibility(Operation *op) const override { + auto cat = cast(op); + int64_t operandRegs = + getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2; + int64_t resultRegs = getNumNonBroadcastRegisters(cat.getType()); + if (resultRegs != operandRegs) { + return op->emitError("tt.cat result encoding requires ") + << resultRegs + << " non-broadcast register values, but operands provide " + << operandRegs; + } + return success(); + } + // Given a src shape + encoding and a dst shape, our goal is to compute a dst // encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z] // contains elements [a,b,c,d] before the reshape, it contains those same @@ -3284,11 +3309,17 @@ struct TritonGPUInferLayoutInterface LogicalResult inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, ArrayRef dstShape, Attribute &dstEnc, + bool allowReorder, std::optional loc) const override { if (product(srcShape) != product(dstShape)) { return emitOptionalError(loc, "numel of dst shape does not match " "numel of src shape"); } + // If allowReorder is true, there are multiple valid encodings. Prefer the + // hint if it is set and valid. + if (allowReorder && dstEnc) + if (!isExpensiveView(srcShape, srcEnc, dstShape, dstEnc)) + return success(); auto result = inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc); if (succeeded(result)) { diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 0cd5bd71f31a..2e5125c3c377 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -355,7 +355,7 @@ struct CanonicalizeConvertFromConvert // cvt(cat) -> cat if (auto cat = dyn_cast(arg)) { - if (isExpensiveCat(cat, op.getType().getEncoding())) + if (!isLegalCatEncoding(cat, op.getType().getEncoding())) return failure(); rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index d2302dcabed1..5b7a04123ffa 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -657,6 +657,8 @@ bool canBeRemat(Operation *op) { return false; if (auto gather = dyn_cast(op)) return !gather.getEfficientLayout(); + if (auto reshape = dyn_cast(op)) + return !reshape.getEfficientLayout(); if (isa(op)) return false; diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index f8fe3361ce0d..edc8c8ff57e8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -465,26 +465,22 @@ static Attribute inferSrcEncoding(triton::TransposeOpInterface op, static Attribute inferReshapeOpDstEncoding(ArrayRef srcShape, Attribute srcEnc, ArrayRef dstShape, - bool allowReorder) { - // We don't do anything smart to allow-reorder reshapes here. They are - // handled in OptimizeThreadLocality. - if (allowReorder) - return {}; - - Attribute dstEnc; + Attribute dstEncHint = {}, + bool allowReorder = false) { + Attribute dstEnc = dstEncHint; auto result = srcEnc.getDialect() .getRegisteredInterface() ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc, - /*loc=*/std::nullopt); + allowReorder, /*loc=*/std::nullopt); assert(succeeded(result)); return dstEnc; } static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) { - return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding, - op.getType().getShape(), - op.getAllowReorder()); + return inferReshapeOpDstEncoding( + op.getSrc().getType().getShape(), encoding, op.getType().getShape(), + op.getType().getEncoding(), op.getAllowReorder()); } static Attribute inferDstEncoding(GatherOp op, Attribute encoding) { @@ -499,9 +495,9 @@ static Attribute inferSrcEncoding(triton::ReshapeOp op, Attribute encoding) { // as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an // invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this // way. - return inferReshapeOpDstEncoding(op.getType().getShape(), encoding, - op.getSrc().getType().getShape(), - op.getAllowReorder()); + return inferReshapeOpDstEncoding( + op.getType().getShape(), encoding, op.getSrc().getType().getShape(), + op.getSrc().getType().getEncoding(), op.getAllowReorder()); } static bool isSingleValue(Value value) { @@ -604,26 +600,10 @@ bool isExpensiveLoadOrStore(Operation *op) { return true; } -bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { - if (!op) - return true; - if (isa(op)) - return isExpensiveLoadOrStore(op); - if (isa(op)) - return triton::gpu::isExpensiveCat(cast(op), targetEncoding); - if (isa(op)) - return true; - if (isa( - op)) - return true; - return false; -} - bool canUseResultEncoding(Operation *op, Attribute targetEncoding) { if (isa(op)) - return !triton::gpu::isExpensiveCat(cast(op), - targetEncoding); + return triton::gpu::isLegalCatEncoding(cast(op), + targetEncoding); if (auto convert = dyn_cast(op)) { if (mlir::isa(targetEncoding)) { auto srcEncoding = convert.getSrc().getType().getEncoding(); diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir index a2de394abc2d..a2da914eba7d 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir @@ -295,16 +295,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: join_cat_transitive_nonneg tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1> %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked1> %2 = tt.join %0, %1 : tensor<8xi32, #blocked1> -> tensor<8x2xi32, #blocked> - %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> - %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1> - %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked> - %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked> + %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked2> + %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked2> + %5 = tt.join %3, %4 : tensor<4xi32, #blocked2> -> tensor<4x2xi32, #blocked3> + %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked3> -> tensor<8x2xi32, #blocked> %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked> %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked> %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked> diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 94413edc06d1..922717d546b7 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -335,16 +335,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: join_cat_transitive_nonneg tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1> %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked1> %2 = tt.join %0, %1 : tensor<8xi32, #blocked1> -> tensor<8x2xi32, #blocked> - %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> - %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1> - %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked> - %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked> + %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked2> + %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked2> + %5 = tt.join %3, %4 : tensor<4xi32, #blocked2> -> tensor<4x2xi32, #blocked3> + %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked3> -> tensor<8x2xi32, #blocked> %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked> %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked> %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked> diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 602b0886a111..dade5df4ac3f 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2217,6 +2217,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr // ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.slice<{dim = 0, parent = #blocked}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @permuting_reshape_backward_remat + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.return + tt.func public @permuting_reshape_backward_remat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<8x2xi32, #blocked3> { + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr, #blocked1> + %2 = tt.addptr %1, %0 : tensor<16x!tt.ptr, #blocked1>, tensor<16xi32, #blocked1> + %3 = tt.load %2 : tensor<16x!tt.ptr, #blocked1> + %4 = tt.reshape %3 allow_reorder : tensor<16xi32, #blocked1> -> tensor<8x2xi32, #blocked4> + %5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked4> -> tensor<8x2xi32, #blocked3> + tt.return %5 : tensor<8x2xi32, #blocked3> + } + + // CHECK-LABEL: @permuting_reshape_no_backward_remat_efficient_layout + // CHECK: ttg.convert_layout + // CHECK: tt.return + tt.func public @permuting_reshape_no_backward_remat_efficient_layout(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<8x2xi32, #blocked3> { + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr, #blocked1> + %2 = tt.addptr %1, %0 : tensor<16x!tt.ptr, #blocked1>, tensor<16xi32, #blocked1> + %3 = tt.load %2 : tensor<16x!tt.ptr, #blocked1> + %4 = tt.reshape %3 allow_reorder efficient_layout : tensor<16xi32, #blocked1> -> tensor<8x2xi32, #blocked4> + %5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked4> -> tensor<8x2xi32, #blocked3> + tt.return %5 : tensor<8x2xi32, #blocked3> + } +} + +// ----- + #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> #blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> @@ -4176,3 +4211,47 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 1 : i32} tt.return %o : tensor<1x2x2xi32, #dst> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#linear = #ttg.linear<{register = [[1], [16]], lane = [[0], [0], [2], [4], [8]], warp = [[0], [0]], block = []}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @cat_incompatible_target_keeps_convert + tt.func public @cat_incompatible_target_keeps_convert(%out: !tt.ptr) { + %lhs = arith.constant dense<0> : tensor<16xi32, #blocked> + %rhs = arith.constant dense<1> : tensor<16xi32, #blocked> + // CHECK: %[[CAT:[^ ]+]] = tt.cat + %cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2> + // CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]] + %cvt = ttg.convert_layout %cat {allocation.offset = 0 : i32} : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear> + %ptr = tt.splat %out : !tt.ptr -> tensor<32x!tt.ptr, #linear> + // CHECK: tt.store {{.*}}, %[[CVT]] + tt.store %ptr, %cvt : tensor<32x!tt.ptr, #linear> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#linear_bcast = #ttg.linear<{register = [[0]], lane = [[1], [2], [4], [8], [16]], warp = [[0], [0]], block = []}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @cat_target_adds_broadcasting_keeps_convert + tt.func public @cat_target_adds_broadcasting_keeps_convert(%out: !tt.ptr) { + %lhs = arith.constant dense<0> : tensor<16xi32, #blocked> + %rhs = arith.constant dense<1> : tensor<16xi32, #blocked> + // CHECK: %[[CAT:[^ ]+]] = tt.cat + %cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #blocked2> + // CHECK: %[[CVT:[^ ]+]] = ttg.convert_layout %[[CAT]] + %cvt = ttg.convert_layout %cat : tensor<32xi32, #blocked2> -> tensor<32xi32, #linear_bcast> + %ptr = tt.splat %out : !tt.ptr -> tensor<32x!tt.ptr, #linear_bcast> + // CHECK: tt.store {{.*}}, %[[CVT]] + tt.store %ptr, %cvt : tensor<32x!tt.ptr, #linear_bcast> + tt.return + } +} diff --git a/test/TritonGPU/verify-blocked-layout.mlir b/test/TritonGPU/verify-blocked-layout.mlir index 9fd372b55056..91b90d2276fa 100644 --- a/test/TritonGPU/verify-blocked-layout.mlir +++ b/test/TritonGPU/verify-blocked-layout.mlir @@ -114,3 +114,17 @@ module attributes { tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#linear = #ttg.linear<{register = [[1], [16]], lane = [[0], [0], [2], [4], [8]], warp = [[0], [0]], block = []}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @invalid_cat_layout() { + %lhs = arith.constant dense<0> : tensor<16xi32, #blocked> + %rhs = arith.constant dense<1> : tensor<16xi32, #blocked> + // expected-error @+1 {{tt.cat result encoding requires 4 non-broadcast register values, but operands provide 2}} + %cat = tt.cat %lhs, %rhs : tensor<16xi32, #blocked> -> tensor<32xi32, #linear> + tt.return + } +} diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index 34c51cc51caf..ae5932817171 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -139,7 +139,7 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); result = inferLayout->inferReshapeOpEncoding( srcTy.getShape(), srcTy.getEncoding(), dstTy.getShape(), inferredEnc, - UnknownLoc::get(ctx)); + /*allowReorder=*/false, UnknownLoc::get(ctx)); } // We expect the reshape to succeed as long as the inputs have the same @@ -164,7 +164,7 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, Attribute inferredSrcEnc; auto result = inferLayout->inferReshapeOpEncoding( dstTy.getShape(), inferredEnc, srcTy.getShape(), inferredSrcEnc, - UnknownLoc::get(ctx)); + /*allowReorder=*/false, UnknownLoc::get(ctx)); EXPECT_TRUE(succeeded(result)) << "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x") << " " << stringifyLLVMType(inferredEnc) << " -> " @@ -439,7 +439,8 @@ TEST_F(JoinOpTest, JoinOpLayoutPropagation) { } Attribute reshapedEnc; result = inferLayout->inferReshapeOpEncoding( - transShape, transEnc, newShape, reshapedEnc, std::nullopt); + transShape, transEnc, newShape, reshapedEnc, + /*allowReorder=*/false, std::nullopt); assert(succeeded(result)); // The layouts should be structurally the same // but reshapeEnc will likely be a LinearEncodingAttr