Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ compared to 1*64 when the hasLeadingOffset is false.
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"unsigned":$typeWidthInBit), [{
bool needTrans = false; // default value
return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans);
}]>,

AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"unsigned":$typeWidthInBit,
"bool":$needTrans), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();

if(!mmaEnc)
Expand Down Expand Up @@ -152,16 +162,23 @@ compared to 1*64 when the hasLeadingOffset is false.

// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int m = (needTrans) ? matShape[2] : matShape[0];
int k = (needTrans) ? matShape[0] : matShape[2];
int vec = (order[0] == 1) ? k : m;
int mmaStride = (order[0] == 1) ? m : k;
int maxPhase = mmaStride / perPhase;
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

// --- handle B operand ---
if (opIdx == 1) {
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
// we compute vec and maxPhase m, n and k size of the mma
// instruction. when matmul operands is transposed, we should
// consider that to get m, n and k.
int n = needTrans ? matShape[2] : matShape[1];
int k = needTrans ? matShape[1] : matShape[2];
int vec = (order[0] == 1) ? n : k;
int mmaStride = (order[0] == 1) ? k : n;
int maxPhase = mmaStride / perPhase;
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
Expand Down Expand Up @@ -189,6 +206,16 @@ compared to 1*64 when the hasLeadingOffset is false.
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
}]>,

AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"Type":$eltTy,
"bool":$needTrans), [{
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans);
}]>,

AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
Expand Down
7 changes: 6 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,14 @@ class ConvertTransConvert : public mlir::RewritePattern {
// used here. For tests where numCTAs = 1, this is not a problem since all
// CTALayouts are the same.
auto newXOrder = triton::gpu::getOrder(argEncoding);
// set needTrans to true here. newXEncoding is computed based on argEncoding
// which is before the transpose. without needTrans we will compute vec and
// maxPhase based on incorrect m, n and k size of mma. the type inference of
// TransOp simply swap the order but doesn't fix the vec and maxPhase for
// the YType, hence it would causing incorrect swizzling code.
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
getContext(), ZEncoding, XType.getShape(), newXOrder,
XEncoding.getCTALayout(), XType.getElementType());
XEncoding.getCTALayout(), XType.getElementType(), true);
auto newXType = RankedTensorType::get(XType.getShape(),
XType.getElementType(), newXEncoding);
if (XEncoding == newXEncoding)
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,10 +652,12 @@ void LoopPipeliner::createBufferTypes() {
.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
// MMAv1 and MMAv2
bool needTrans = dyn_cast_or_null<tt::TransOp>(
cvt.getDefiningOp()->getOperand(0).getDefiningOp());
unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth();
sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth);
ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans);
} else {
// MMAv3
sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(),
Expand Down