diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 3e8fcdf24587..b209a02b4bb3 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -902,10 +902,12 @@ inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, auto rank = shapePerCta.size(); assert(rank == 2 || rank == 3); SmallVector elemOffset(rank, 0); + auto elemStride = wmmaLayout.getVersion() == 1 ? 2 : 1; if (rank == 3) elemOffset[0] = ctaBatchOffset; for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { - elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem; + elemOffset[rank - 2] = + ctaOffsetX * shapePerCta[rank - 2] + elemStride * elem; elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; offsets.push_back(elemOffset); } @@ -951,8 +953,17 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, SmallVector multiDimBase(rank); - multiDimBase[rank - 2] = - add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + auto ver = wmmaLayout.getVersion(); + if (ver == 1) { + multiDimBase[rank - 2] = + add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + } else { + assert(ver == 2); + multiDimBase[rank - 2] = + add(mul(udiv(threadIdPerWarp, i32_val(mnkDim[2])), + i32_val(wmmaLayout.getSizePerThread()[rank - 2])), + offWarp0); + } multiDimBase[rank - 1] = add(laneId, offWarp1); // TODO: It is assumed when rank = 3, warpsPerCTA is set to @@ -1102,8 +1113,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, } else if (auto mfmaLayout = mlir::dyn_cast(layout)) { result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type); } else if (auto wmmaLayout = mlir::dyn_cast(layout)) { - // TODO: support 2nd gen of WMMA - assert(wmmaLayout.getVersion() == 1); result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type); } else if (auto sliceLayout = mlir::dyn_cast(layout)) { auto 
parentLayout = sliceLayout.getParent(); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 458a1ed9d9b6..893d15876dc0 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -814,8 +814,6 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0], multiDimCTAInRepId[1]); } else if (auto wmmaLayout = dyn_cast(layout)) { - // TODO: support 2nd gen of WMMA - assert(wmmaLayout.getVersion() == 1); emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0], multiDimCTAInRepId[1]); } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index a65b9e64e2a5..a4f30fc503ba 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -565,17 +565,35 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { // For wmma with 16x16 output, each of the 32 threads holds 8 elements. // - // For the register (i.e., element) dimension, these 8 elements are along - // the matrix C's M dimension, with 1 consecutive elements spanning 1 row - // and then the next 1 row being a gap. + // The first version of the WMMA layout has the following specifics: + // for the register (i.e., element) dimension, these 8 elements are + // along the matrix C's M dimension, with 1 consecutive element + // spanning 1 row and then the next 1 row being a gap. // // For the lane (i.e., thread) dimension, these threads are along the // matrix C's N dimension, with 16 consecutive threads covering a whole // row and the next 16 threads start at the next row. 
- LinearLayout tileLayout( - {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, - {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, - {outDimNames[order[0]], outDimNames[order[1]]}); + // + // The second version of the WMMA layout is less tricky: + // for the register dimension, the 8 elements are along the matrix C's M + // dimension. The first 16 lanes take elems 0-7 along M; the second 16 take + // elems 8-15. We have 16 pairs of threads in each warp; one pair covers the + // whole column. + // + // Please also check the explanatory comments in TritonGPUAttrDefs.td at the + // AMDWmmaEncodingAttr section. + unsigned ver = getVersion(); + assert(ver == 1 || ver == 2); + LinearLayout tileLayout = + ver == 1 + ? LinearLayout( + {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}) + : LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 4}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); if (hasBatchDim) { assert(order[2] == 0); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index 79fa319ba978..23dd5d37d520 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -77,8 +77,6 @@ createNewConvertOps(ModuleOp &mod, OpBuilder &builder, srcType.getElementType(), newMfmaEnc); } else if (auto srcWmma = dyn_cast( srcType.getEncoding())) { - // TODO: support 2nd gen of WMMA - assert(srcWmma.getVersion() == 1); auto newWmmaEnc = triton::gpu::AMDWmmaEncodingAttr::get( mod.getContext(), srcWmma.getVersion(), {warpsPerCtaX, warpsPerCtaY}, srcWmma.getCTALayout());