48 changes: 46 additions & 2 deletions include/triton/Dialect/Triton/IR/Traits.h
@@ -4,8 +4,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"

#include <iostream>
#include "triton/Dialect/Triton/IR/Types.h"

namespace mlir {
namespace OpTrait {
@@ -58,6 +57,51 @@ class VerifyTensorLayoutsTrait
}
};

// Verifies that the op is a dot-like operation.
// A dot-like operation has three operands: the first two share a common
// (contraction) dimension, and the result carries the dimensions of the two
// operands that are not shared.
// A dot-like operation can be either 2d or 3d.
// In the 3d case, the first dimension of each operand is the batch dimension.
template <class ConcreteType>
class DotLike : public TraitBase<ConcreteType, DotLike> {
public:
static LogicalResult verifyTrait(Operation *op) {
if (op->getNumOperands() != 3)
return op->emitOpError("expected 3 operands");
auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
auto cTy = cast<TensorType>(op->getOperand(2).getType());
auto aShape = aTy.getShape();
auto bShape = bTy.getShape();
auto cShape = cTy.getShape();
// Check if all 3d or all 2d
if (aShape.size() != 2 && aShape.size() != 3)
return op->emitOpError("expected operands to be 2d or 3d");
if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
return op->emitOpError("expected all operands to have the same rank");
// Check that the first two operands share a common dimension
if (aShape[aShape.size() - 1] != bShape[bShape.size() - 2])
return op->emitOpError("expected the last dimension of the first operand "
"to be equal to the second-to-last dimension of "
"the second operand");
// Check the batch dimension
if (aShape.size() == 3 &&
(aShape[0] != cShape[0] || bShape[0] != cShape[0]))
return op->emitOpError("expected the first (batch) dimension of both "
"operands to be equal to the first dimension of "
"the result");
// Check the output shape
if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] ||
cShape[cShape.size() - 1] != bShape[bShape.size() - 1])
return op->emitOpError(
"expected the output shape to be the concatenation of the "
"second-to-last dimension of the first operand and the last "
"dimension of the second operand");
return success();
}
};

template <typename ConcreteType>
class SameOperandsAndResultEncoding
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
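To make the shape rule enforced by the new DotLike trait concrete, here is a batched (3d) dot that verifies: the contraction dimension of the first two operands matches, and the result keeps the batch dimension plus the non-shared dimensions. This is an illustrative sketch, not part of the diff; the SSA names are hypothetical and layout encodings are elided.

    // a: 4x32x64, b: 4x64x128 -> the shared dim 64 contracts; the batch dim 4
    // and the non-shared dims 32 and 128 carry into the result.
    %d = tt.dot %a, %b, %c : tensor<4x32x64xf16> * tensor<4x64x128xf16> -> tensor<4x32x128xf32>
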
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonInterfaces.td
@@ -5,6 +5,7 @@ include "mlir/IR/OpBase.td"

def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
def DotLike : NativeOpTrait<"DotLike">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
5 changes: 3 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -625,6 +625,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
//
def TT_DotOp : TT_Op<"dot", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
DotLike,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
@@ -640,8 +641,8 @@ def TT_DotOp : TT_Op<"dot", [Pure,

let arguments = (
ins
TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$a,
TT_FpIntTensor:$b,
TT_FpIntTensor:$c,
DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc
3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -25,7 +25,8 @@ def TT_BoolTensor : RankedTensorOf<[I1]>;
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;

// Integer Type
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def I4 : I<4>;
def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">;
def TT_IntTensor : RankedTensorOf<[TT_Int]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;

2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -169,7 +169,7 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedM

Because we assume a memdesc is dead at the first point that post-dominates
its uses, ops that wait for an async operation on a memdesc to complete
(such as triton_nvidia_gpu.dot_wait) should also take the memdesc as an
(such as triton_nvidia_gpu.warp_group_dot_wait) should also take the memdesc as an
operand.
}];

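As a concrete reading of the local_dealloc note above: an op that waits on an async dot can take the memdesc itself as an operand, so the wait post-dominates the buffer's last use and the allocation is not treated as dead while the dot may still be reading it. A hypothetical sketch, not taken from the PR; %x, %bufB, %c and the #blocked/#shared/#shared1/#mma encodings are assumed to be defined elsewhere.

    %buf = triton_gpu.local_alloc %x : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared>
    %acc = triton_nvidia_gpu.warp_group_dot %buf, %bufB, %c {isAsync = true} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x256xf16, #shared1> -> tensor<128x256xf32, #mma>
    // Threading %buf through the wait keeps the buffer alive until the dot completes.
    %accw, %bufw = triton_nvidia_gpu.warp_group_dot_wait %acc, %buf {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !tt.memdesc<128x64xf16, #shared>
    triton_gpu.local_dealloc %buf : !tt.memdesc<128x64xf16, #shared>
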
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -23,7 +23,7 @@ class SharedEncodingAttr;

SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
const ArrayRef<int64_t> &shape,
TensorOrMemDesc type,
RankedTensorType type,
int numWarps);

/// Returns true if the Load uses block pointer.
20 changes: 11 additions & 9 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -67,13 +67,14 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
}

//
// DotAsync Op
// WarpGroupDot Op
//
def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DotLike,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot async";
let summary = "warp group dot";

let description = [{
$d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
@@ -82,25 +83,26 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [DeclareOpInterfaceMethods<InferTypeO
let arguments = (ins TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
TT_InputPrecisionAttr:$inputPrecision,
I32Attr:$maxNumImpreciseAcc);
DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
DefaultValuedAttr<BoolAttr, "false">:$isAsync);

let results = (outs TT_FpIntTensor:$d);

let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}

def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
AllTypesMatch<["inputs", "outputs"]>]> {
let summary = "dot wait";
def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
AllTypesMatch<["inputs", "outputs"]>]> {
let summary = "warp group dot wait";
let arguments = (ins Variadic<TT_TensorOrMemDesc>:$inputs, I32Attr:$pendings);
let results = (outs Variadic<TT_TensorOrMemDesc>:$outputs);
let description = [{
Waits until there are $pendings or fewer outstanding async dot operations.

$inputs must be the tensors corresponding to the async dot ops that we're
waiting on. For example, if there are N pending async dot ops and we call
`dot_wait 1`, then $inputs must be the result of the first dot op.
`warp_group_dot_wait 1`, then $inputs must be the results of the first N-1
async dot ops.
}];

let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
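To make the $pendings contract concrete: with two async dots in flight, waiting with pendings = 1 guarantees that only the newer dot may still be outstanding, so the first dot's result is what gets passed in. A hypothetical sketch; shapes are illustrative and layout encodings are elided.

    %d0 = triton_nvidia_gpu.warp_group_dot %a0, %b0, %c0 {isAsync = true} : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32>
    %d1 = triton_nvidia_gpu.warp_group_dot %a1, %b1, %c1 {isAsync = true} : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32>
    // At most one dot may still be outstanding afterwards, so %d0 is safe to read.
    %r0 = triton_nvidia_gpu.warp_group_dot_wait %d0 {pendings = 1 : i32} : tensor<128x256xf32>
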
11 changes: 0 additions & 11 deletions lib/Analysis/Membar.cpp
@@ -145,17 +145,6 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
}
}
}
// XXX(Keren): This is a hack as we cannot set side effects for dot ops, but
// on hopper they do have side effects. Need to clean it up
if (auto dotOp = dyn_cast<triton::DotOp>(op)) {
for (auto value : dotOp.getOperands()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId)
curBlockInfo.syncReadIntervals.insert(
allocation->getAllocatedInterval(bufferId));
}
}
}
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
5 changes: 3 additions & 2 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2741,8 +2741,9 @@ struct CanonicalizeConvertFromConvert
// for hopper MMAv3
if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(),
[](Operation *dot) { return isa<DotOp>(dot); })) {
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {
return failure();
}

14 changes: 8 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -299,11 +299,14 @@ class BlockedToMMA : public mlir::RewritePattern {
auto newAcc =
rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);

Operation *newDot = nullptr;
if (versionMajor == 3) {
a = getMMAv3Operand(a, rewriter, 0);
b = getMMAv3Operand(b, rewriter, 1);
newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc(), false);
} else {

// convert operands
int minBitwidth =
std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
@@ -322,14 +325,13 @@
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(), newBEncoding);
b = rewriter.create<ConvertLayoutOp>(b.getLoc(), newBType, b);
newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, a, b, newAcc,
dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc());
}
// convert dot instruction
auto newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, a, b,
newAcc, dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc());

rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, oldRetType,
newDot.getResult());
newDot->getResult(0));
return success();
}
};
7 changes: 4 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -210,7 +210,7 @@ class FuseTransHopper : public OpRewritePattern<LocalAllocOp> {
LogicalResult matchAndRewrite(LocalAllocOp allocOp,
PatternRewriter &rewriter) const override {
if (!allocOp->hasOneUse() ||
!isa<DotOp, nvidia_gpu::DotAsyncOp>(*allocOp->getUsers().begin()))
!allocOp->getUsers().begin()->hasTrait<OpTrait::DotLike>())
return failure();

auto dot = *allocOp->getUsers().begin();
@@ -268,10 +268,11 @@ class FuseTransHopper : public OpRewritePattern<LocalAllocOp> {
// dot(convert(lhs #mma) #shared, rhs) #mma ->
// dot(convert(lhs #mma) #dot_operand, rhs) #mma,
// for fp16 or bf16 MMAv3 dots.
struct MMAV3UseRegOperand : public OpRewritePattern<DotOp> {
struct MMAV3UseRegOperand
: public OpRewritePattern<triton::nvidia_gpu::WarpGroupDotOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(DotOp dotOp,
LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp,
PatternRewriter &rewriter) const override {
auto alloc = dotOp.getOperand(0).getDefiningOp<LocalAllocOp>();
if (!alloc || !alloc.getSrc())