Skip to content
Merged
57 changes: 34 additions & 23 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -781,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {

// Interface query described as "Return shape per CTA.".
// Arguments are (tensorShape, kWidth, opIdx), the same order used by the
// getShapePerCTATileForOperand member declarations elsewhere in this file.
InterfaceMethod<"Return shape per CTA.",
                "SmallVector<unsigned>",
                "getShapePerCTATileForOperand",
                (ins "ArrayRef<int64_t>":$tensorShape,
                     "int":$kWidth,
                     "int":$opIdx)>,

// Interface query described as "Return total element size per thread for dot
// operands.". Arguments are (tensorShape, eltTy, kWidth, opIdx), the same order
// used by the getTotalElemsPerThreadForOperand member declarations elsewhere in
// this file.
InterfaceMethod<"Return total element size per thread for dot operands.",
                "unsigned",
                "getTotalElemsPerThreadForOperand",
                (ins "ArrayRef<int64_t>":$tensorShape,
                     "Type":$eltTy,
                     "int":$kWidth,
                     "int":$opIdx)>,

// Interface query described as "Return size per thread for dot operands.".
// Arguments are named (kWidth, opIdx) so that they line up with the
// getSizePerThreadForOperand(int kWidth, int opIdx) member declarations and
// with the other interface methods in this list. Both arguments are `int`,
// so a name/position mismatch here would not be caught by the compiler.
InterfaceMethod<"Return size per thread for dot operands.",
                "SmallVector<unsigned>",
                "getSizePerThreadForOperand",
                (ins "int":$kWidth,
                     "int":$opIdx)>,
];
}

Expand Down Expand Up @@ -914,11 +916,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
// Unconditionally reports that reduction is supported.
bool supportReduction() const { return true; }
// Per-operand query declarations; their definitions are not in this part of
// the file.
// NOTE(review): the five declarations below look like the pre-rename spellings
// ("...Operands" / "MFMA...") of the five that follow them — confirm which set
// is still defined before calling either.
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getMFMAInstrShapeForOperands(int kWidth, int opIdx) const;
SmallVector<int64_t> getMFMARepForOperands(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
// "...ForOperand" spellings: every query takes kWidth, and kWidth always comes
// immediately before opIdx. Both are plain `int`, so swapping them at a call
// site compiles silently.
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;

SmallVector<unsigned> getContigPerThread() {
auto rank = getWarpsPerCTA().size();
Expand Down Expand Up @@ -1021,12 +1023,12 @@ Row | warp 0 warp 2
// Unconditionally reports that reduction is supported.
bool supportReduction() const { return true; }
// Per-operand query declarations; their definitions are not in this part of
// the file.
// NOTE(review): the three "...Operands" / "...ForDotOperands" declarations
// below look like the pre-rename spellings of the three "...ForOperand"
// declarations that follow — confirm which set is still defined.
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
// "...ForOperand" spellings: kWidth is passed explicitly and comes immediately
// before opIdx in each signature.
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
// Takes no arguments, unlike the per-operand queries around it.
SmallVector<int64_t> getElemsPerInstrForOperands() const;
// NOTE(review): getRepForOperands and getRepForOperand have identical
// parameter lists; the plural form looks like the pre-rename name — confirm.
SmallVector<int64_t> getRepForOperands(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
// Static: callable without an attribute instance.
static SmallVector<unsigned> getMNKDimPerInstr();

SmallVector<unsigned> getContigPerThread() {
Expand Down Expand Up @@ -1222,18 +1224,18 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
// MMAv1 helpers, each keyed only by the operand index.
SmallVector<int> getMMAv1Rep(int opIdx) const;
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
int getMMAv1Vec(int opIdx) const;
// MMAv2 repetition query, declared in two forms: without and with kWidth.
// NOTE(review): getMMAv2Rep looks like the pre-rename form of
// getMMAv2RepForOperand — confirm which one is still defined.
SmallVector<int64_t> getMMAv2Rep(ArrayRef<int64_t> shape,
int bitwidth, int opIdx) const;
SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int kWidth, int opIdx) const;

// Reduction is reported as supported exactly when isAmpere() or isHopper()
// holds; every other case yields false.
bool supportReduction() const {
  return isAmpere() || isHopper();
};
// Per-operand query declarations; their definitions are not in this part of
// the file.
// NOTE(review): the three "...Operands" / "...ForDotOperands" declarations
// below look like the pre-rename spellings of the three "...ForOperand"
// declarations that follow — confirm which set is still defined.
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
// "...ForOperand" spellings: kWidth is passed explicitly and comes immediately
// before opIdx in each signature.
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;

SmallVector<unsigned> getContigPerThread() {
assert(isVolta() || isAmpere() || isHopper());
Expand Down Expand Up @@ -1344,7 +1346,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
let genVerifyDecl = 1;
let extraClassDeclaration = extraDistributedDeclaration # [{
// Returns a rank-sized vector that is 1 in every dimension except one, which
// holds kWidth.
//
// The rank is taken from getWarpsPerCTA() and must be 2 or 3. kWidth must be
// non-zero. Which dimension receives kWidth depends on the operand index:
//   - opIdx == 0: the last dimension (rank - 1)
//   - otherwise:  the second-to-last dimension (rank - 2)
// NOTE(review): presumably these are the K dimension of each dot operand —
// confirm against the encoding's documentation.
SmallVector<unsigned> getContigPerThread() {
  auto rank = getWarpsPerCTA().size();
  assert(rank == 2 || rank == 3);
  SmallVector<unsigned> contigPerThread(rank, 1);
  auto kWidth = getKWidth();
  assert(kWidth != 0 && "Do not support kWidth=0");
  if (getOpIdx() == 0)
    contigPerThread[rank - 1] = kWidth;
  else
    contigPerThread[rank - 2] = kWidth;
  return contigPerThread;
};
}];
}
Expand Down
Loading