diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 2e3797af940a..7cfe7f448d42 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -113,6 +113,16 @@ compared to 1*64 when the hasLeadingOffset is false. "ArrayRef":$order, "CTALayoutAttr":$CTALayout, "unsigned":$typeWidthInBit), [{ + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ auto mmaEnc = dotOpEnc.getParent().dyn_cast(); if(!mmaEnc) @@ -152,16 +162,23 @@ compared to 1*64 when the hasLeadingOffset is false. // --- handle A operand --- if (opIdx == 0) { // compute swizzling for A operand - int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m - int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2]; + int m = (needTrans) ? matShape[2] : matShape[0]; + int k = (needTrans) ? matShape[0] : matShape[2]; + int vec = (order[0] == 1) ? k : m; + int mmaStride = (order[0] == 1) ? m : k; int maxPhase = mmaStride / perPhase; return get(context, vec, perPhase, maxPhase, order, CTALayout); } // --- handle B operand --- if (opIdx == 1) { - int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k - int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1]; + // we compute vec and maxPhase m, n and k size of the mma + // instruction. when matmul operands is transposed, we should + // consider that to get m, n and k. + int n = needTrans ? matShape[2] : matShape[1]; + int k = needTrans ? matShape[1] : matShape[2]; + int vec = (order[0] == 1) ? n : k; + int mmaStride = (order[0] == 1) ? k : n; int maxPhase = mmaStride / perPhase; return get(context, vec, perPhase, maxPhase, order, CTALayout); } @@ -189,6 +206,16 @@ compared to 1*64 when the hasLeadingOffset is false. return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); }]>, + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans); + }]>, + AttrBuilder<(ins "ArrayRef":$shape, "ArrayRef":$order, "CTALayoutAttr":$CTALayout, diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 14a050472800..5a1d93c2569a 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -60,9 +60,14 @@ class ConvertTransConvert : public mlir::RewritePattern { // used here. For tests where numCTAs = 1, this is not a problem since all // CTALayouts are the same. auto newXOrder = triton::gpu::getOrder(argEncoding); + // set needTrans to true here. newXEncoding is computed based on argEncoding + // which is before the transpose. without needTrans we will compute vec and + // maxPhase based on incorrect m, n and k size of mma. the type inference of + // TransOp simply swap the order but doesn't fix the vec and maxPhase for + // the YType, hence it would causing incorrect swizzling code. auto newXEncoding = triton::gpu::SharedEncodingAttr::get( getContext(), ZEncoding, XType.getShape(), newXOrder, - XEncoding.getCTALayout(), XType.getElementType()); + XEncoding.getCTALayout(), XType.getElementType(), true); auto newXType = RankedTensorType::get(XType.getShape(), XType.getElementType(), newXEncoding); if (XEncoding == newXEncoding) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index db5513d92cfb..13de5d266cb6 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -652,10 +652,12 @@ void LoopPipeliner::createBufferTypes() { .getEncoding() .dyn_cast()) { // MMAv1 and MMAv2 + bool needTrans = dyn_cast_or_null( + cvt.getDefiningOp()->getOperand(0).getDefiningOp()); unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); sharedEnc = ttg::SharedEncodingAttr::get( ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth); + ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); } else { // MMAv3 sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(),