Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,13 @@ class DialectInferLayoutInterface
// makes the reshape a "nop", i.e. the same GPU threads contain the same
// elements as before the reshape using legacy layouts. This is not always
// possible (in which case we fallback to using LinearLayouts)
// If allowReorder is set, an existing value in dstEnc is preferred when it
// still yields a non-expensive view.
// In the future we'll always use LinearLayouts
virtual LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
bool allowReorder,
std::optional<Location> loc) const = 0;

// Check if two layouts are structurally the same, even if their names are
Expand All @@ -87,6 +90,11 @@ class DialectInferLayoutInterface
verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
Attribute operandEncodingB) const = 0;

// Verify that the encodings are compatible to be used together in a cat
// operation.
virtual LogicalResult
verifyCatOpEncodingCompatibility(Operation *op) const = 0;

virtual LogicalResult
inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
Attribute &outEnc, bool fwdInference,
Expand Down
2 changes: 2 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
let results = (outs TT_Tensor:$result);

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";

let hasVerifier = 1;
}

def TT_JoinOp : TT_Op<"join", [
Expand Down
12 changes: 10 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,18 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kContig);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
// Return true if \p cat would be valid with result encoding \p targetEncoding.
bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding);

// Return true if a view between the two types cannot be implemented as a no-op.
bool isExpensiveView(Type srcType, Type dstType);
bool isExpensiveView(ArrayRef<int64_t> srcShape, Attribute srcEncoding,
ArrayRef<int64_t> dstShape, Attribute dstEncoding);
inline bool isExpensiveView(Type srcType, Type dstType) {
auto tensorSrcType = cast<RankedTensorType>(srcType);
auto tensorDstType = cast<RankedTensorType>(dstType);
return isExpensiveView(tensorSrcType.getShape(), tensorSrcType.getEncoding(),
tensorDstType.getShape(), tensorDstType.getEncoding());
}

// Return a blocked encoding where the shape is distributed contiguously amongst
// the threads, warps, CTAs with 1 element per threads.
Expand Down
10 changes: 1 addition & 9 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,7 @@ struct CatOpConversion : public ConvertOpToLLVMPattern<CatOp> {
for (Value v : rhsVals)
retVals.push_back(v);

if (retVals.size() != strippedDstLayout.getInDimSize(kReg)) {
return op->emitError()
<< "tt.cat lowering expected "
<< strippedDstLayout.getInDimSize(kReg)
<< " (non-broadcast) register values for the result, but got "
<< retVals.size()
<< ". (hint: this usually means the operands/result encodings are "
"incompatible for the current CatOp lowering)";
}
assert(retVals.size() == strippedDstLayout.getInDimSize(kReg));

// Re-introduce broadcasting if the destination expects it.
if (!removeBroadcastDst.isIdentity())
Expand Down
6 changes: 5 additions & 1 deletion lib/Dialect/Gluon/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface {
return success();
}

LogicalResult verifyCatOpEncodingCompatibility(Operation *op) const override {
return success();
}

LogicalResult
verifyLayoutsAreEqual(ArrayRef<int64_t> shape, Attribute expected,
Attribute got,
Expand All @@ -74,7 +78,7 @@ struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface {

LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc, bool,
std::optional<Location> loc) const override {
return inferAutoEncoding(srcEnc, dstEnc);
}
Expand Down
35 changes: 31 additions & 4 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,31 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
return foldViewLikeOp(*this, adaptor.getSrc());
}

//-- CatOp --
LogicalResult CatOp::verify() {
RankedTensorType lhsTy = getLhs().getType();
RankedTensorType resultTy = getType();

int64_t operandElements = lhsTy.getNumElements() * 2;
if (resultTy.getNumElements() != operandElements) {
return emitOpError("result element count must equal the sum of the "
"operand element counts, expected ")
<< operandElements << " but got " << resultTy.getNumElements();
}

Attribute operandEnc = lhsTy.getEncoding();
Attribute resultEnc = resultTy.getEncoding();
if (!!operandEnc != !!resultEnc) {
return emitOpError("requires that either (a) operands and result all have "
"encodings, or (b) none do.");
}
if (!resultEnc)
return success();

auto interface = cast<DialectInferLayoutInterface>(&resultEnc.getDialect());
return interface->verifyCatOpEncodingCompatibility(getOperation());
}

//-- ReshapeOp --

void ReshapeOp::build(OpBuilder &builder, OperationState &state,
Expand All @@ -870,9 +895,10 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &state,
auto srcEnc = srcTy.getEncoding();
Attribute dstEnc;
if (srcEnc) {
auto result = cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape,
dstEnc, state.location);
auto result =
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, dstEnc,
allowReorder, state.location);
assert(succeeded(result));
}
auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
Expand Down Expand Up @@ -942,7 +968,8 @@ LogicalResult ReshapeOp::verify() {
auto layoutInterface =
cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
auto result = layoutInterface->inferReshapeOpEncoding(
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc());
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc,
/*allowReorder=*/false, getLoc());
if (failed(result))
return failure();
return layoutInterface->verifyLayoutsAreEqual(
Expand Down
63 changes: 47 additions & 16 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,10 @@ SmallVector<unsigned> getContigPerThread(RankedTensorType type) {
return toLinearEncoding(type).getContigPerThread();
}

bool isExpensiveView(Type srcType, Type dstType) {
auto tensorSrcType = cast<RankedTensorType>(srcType);
auto tensorDstType = cast<RankedTensorType>(dstType);
auto llSrc = toLinearLayout(tensorSrcType);
auto llDst = toLinearLayout(tensorDstType);
bool isExpensiveView(ArrayRef<int64_t> srcShape, Attribute srcEncoding,
ArrayRef<int64_t> dstShape, Attribute dstEncoding) {
auto llSrc = toLinearLayout(srcShape, srcEncoding);
auto llDst = toLinearLayout(dstShape, dstEncoding);
// In case there are replicated value we need to make sure the new and old
// layout have matching masks.
for (auto [srcMask, dstMask] :
Expand All @@ -127,7 +126,8 @@ bool isExpensiveView(Type srcType, Type dstType) {
if (srcMask.second != dstMask.second)
return true;
}
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
return getTotalElemsPerThread(srcEncoding, srcShape) !=
getTotalElemsPerThread(dstEncoding, dstShape);
}

/* Utility function used by get.*Order methods of SliceEncodingAttr.
Expand Down Expand Up @@ -409,16 +409,27 @@ SmallVector<unsigned> orderPerDimImpl(const LinearLayout &ll,
return order.takeVector();
}

bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
// If the new elements per thread is less than the old one, we will need to
// do convert encoding that goes through shared memory anyway. So we
// consider it as expensive.
RankedTensorType tensorTy = cat.getType();
auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy);
auto shape = tensorTy.getShape();
auto newTotalElemsPerThread =
gpu::getTotalElemsPerThread(targetEncoding, shape);
return newTotalElemsPerThread < totalElemsPerThread;
static int64_t getNumNonBroadcastRegisters(ArrayRef<int64_t> shape,
Attribute encoding) {
auto kReg = StringAttr::get(encoding.getContext(), "register");
auto strippedLayout =
toLinearLayout(shape, encoding).removeZeroBasesAlongDim(kReg);
return strippedLayout.getInDimSize(kReg);
}

static int64_t getNumNonBroadcastRegisters(RankedTensorType tensorType) {
return getNumNonBroadcastRegisters(tensorType.getShape(),
tensorType.getEncoding());
}

bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding) {
// Cat lowering concatenates the operands' unique register values. So the
// number of unique register values in the result must be equal to those in
// the operands.
int64_t operandRegs = getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2;
int64_t resultRegs =
getNumNonBroadcastRegisters(cat.getType().getShape(), targetEncoding);
return resultRegs == operandRegs;
}

static LogicalResult
Expand Down Expand Up @@ -3017,6 +3028,20 @@ struct TritonGPUInferLayoutInterface
return success();
}

LogicalResult verifyCatOpEncodingCompatibility(Operation *op) const override {
auto cat = cast<CatOp>(op);
int64_t operandRegs =
getNumNonBroadcastRegisters(cat.getLhs().getType()) * 2;
int64_t resultRegs = getNumNonBroadcastRegisters(cat.getType());
if (resultRegs != operandRegs) {
return op->emitError("tt.cat result encoding requires ")
<< resultRegs
<< " non-broadcast register values, but operands provide "
<< operandRegs;
}
return success();
}

// Given a src shape + encoding and a dst shape, our goal is to compute a dst
// encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z]
// contains elements [a,b,c,d] before the reshape, it contains those same
Expand Down Expand Up @@ -3284,11 +3309,17 @@ struct TritonGPUInferLayoutInterface
LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
bool allowReorder,
std::optional<Location> loc) const override {
if (product(srcShape) != product(dstShape)) {
return emitOptionalError(loc, "numel of dst shape does not match "
"numel of src shape");
}
// If allowReorder is true, there are multiple valid encodings. Prefer the
// hint if it is set and valid.
if (allowReorder && dstEnc)
if (!isExpensiveView(srcShape, srcEnc, dstShape, dstEnc))
return success();
auto result =
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
if (succeeded(result)) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ struct CanonicalizeConvertFromConvert

// cvt(cat) -> cat
if (auto cat = dyn_cast<CatOp>(arg)) {
if (isExpensiveCat(cat, op.getType().getEncoding()))
if (!isLegalCatEncoding(cat, op.getType().getEncoding()))
return failure();

rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,8 @@ bool canBeRemat(Operation *op) {
return false;
if (auto gather = dyn_cast<GatherOp>(op))
return !gather.getEfficientLayout();
if (auto reshape = dyn_cast<ReshapeOp>(op))
return !reshape.getEfficientLayout();

if (isa<scf::WhileOp, scf::ConditionOp>(op))
return false;
Expand Down
44 changes: 12 additions & 32 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,26 +465,22 @@ static Attribute inferSrcEncoding(triton::TransposeOpInterface op,
static Attribute inferReshapeOpDstEncoding(ArrayRef<int64_t> srcShape,
Attribute srcEnc,
ArrayRef<int64_t> dstShape,
bool allowReorder) {
// We don't do anything smart to allow-reorder reshapes here. They are
// handled in OptimizeThreadLocality.
if (allowReorder)
return {};

Attribute dstEnc;
Attribute dstEncHint = {},
bool allowReorder = false) {
Attribute dstEnc = dstEncHint;
auto result =
srcEnc.getDialect()
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc,
/*loc=*/std::nullopt);
allowReorder, /*loc=*/std::nullopt);
assert(succeeded(result));
return dstEnc;
}

static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) {
return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding,
op.getType().getShape(),
op.getAllowReorder());
return inferReshapeOpDstEncoding(
op.getSrc().getType().getShape(), encoding, op.getType().getShape(),
op.getType().getEncoding(), op.getAllowReorder());
}

static Attribute inferDstEncoding(GatherOp op, Attribute encoding) {
Expand All @@ -499,9 +495,9 @@ static Attribute inferSrcEncoding(triton::ReshapeOp op, Attribute encoding) {
// as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an
// invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this
// way.
return inferReshapeOpDstEncoding(op.getType().getShape(), encoding,
op.getSrc().getType().getShape(),
op.getAllowReorder());
return inferReshapeOpDstEncoding(
op.getType().getShape(), encoding, op.getSrc().getType().getShape(),
op.getSrc().getType().getEncoding(), op.getAllowReorder());
}

static bool isSingleValue(Value value) {
Expand Down Expand Up @@ -604,26 +600,10 @@ bool isExpensiveLoadOrStore(Operation *op) {
return true;
}

bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return isExpensiveLoadOrStore(op);
if (isa<triton::CatOp>(op))
return triton::gpu::isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
if (isa<triton::gpu::AsyncCopyGlobalToLocalOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
op))
return true;
return false;
}

bool canUseResultEncoding(Operation *op, Attribute targetEncoding) {
if (isa<triton::CatOp>(op))
return !triton::gpu::isExpensiveCat(cast<triton::CatOp>(op),
targetEncoding);
return triton::gpu::isLegalCatEncoding(cast<triton::CatOp>(op),
targetEncoding);
if (auto convert = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
if (mlir::isa<triton::gpu::NvidiaMmaEncodingAttr>(targetEncoding)) {
auto srcEncoding = convert.getSrc().getType().getEncoding();
Expand Down
10 changes: 6 additions & 4 deletions test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -295,16 +295,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// COMMON-LABEL: join_cat_transitive_nonneg
tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
%0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1>
%1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked1>
%2 = tt.join %0, %1 : tensor<8xi32, #blocked1> -> tensor<8x2xi32, #blocked>
%3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1>
%4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1>
%5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked>
%6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
%3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked2>
%4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked2>
%5 = tt.join %3, %4 : tensor<4xi32, #blocked2> -> tensor<4x2xi32, #blocked3>
%6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked3> -> tensor<8x2xi32, #blocked>
%7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked>
%zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked>
%ones = arith.constant dense<1> : tensor<8x1xi32, #blocked>
Expand Down
10 changes: 6 additions & 4 deletions test/TritonGPU/amd/amd-convert-buffer-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -335,16 +335,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// COMMON-LABEL: join_cat_transitive_nonneg
tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
%0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1>
%1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked1>
%2 = tt.join %0, %1 : tensor<8xi32, #blocked1> -> tensor<8x2xi32, #blocked>
%3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1>
%4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1>
%5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked>
%6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
%3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked2>
%4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked2>
%5 = tt.join %3, %4 : tensor<4xi32, #blocked2> -> tensor<4x2xi32, #blocked3>
%6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked3> -> tensor<8x2xi32, #blocked>
%7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked>
%zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked>
%ones = arith.constant dense<1> : tensor<8x1xi32, #blocked>
Expand Down
Loading
Loading