Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -902,10 +902,12 @@ inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout,
auto rank = shapePerCta.size();
assert(rank == 2 || rank == 3);
SmallVector<unsigned> elemOffset(rank, 0);
auto elemStride = wmmaLayout.getVersion() == 1 ? 2 : 1;
if (rank == 3)
elemOffset[0] = ctaBatchOffset;
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem;
elemOffset[rank - 2] =
ctaOffsetX * shapePerCta[rank - 2] + elemStride * elem;
elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1];
offsets.push_back(elemOffset);
}
Expand Down Expand Up @@ -951,8 +953,17 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter,

SmallVector<Value> multiDimBase(rank);

multiDimBase[rank - 2] =
add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0);
auto ver = wmmaLayout.getVersion();
if (ver == 1) {
multiDimBase[rank - 2] =
add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0);
} else {
assert(ver == 2);
multiDimBase[rank - 2] =
add(mul(udiv(threadIdPerWarp, i32_val(mnkDim[2])),
i32_val(wmmaLayout.getSizePerThread()[rank - 2])),
offWarp0);
}
multiDimBase[rank - 1] = add(laneId, offWarp1);

// TODO: It is assumed when rank = 3, warpsPerCTA is set to
Expand Down Expand Up @@ -1102,8 +1113,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
} else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(layout)) {
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
} else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(layout)) {
// TODO: support 2nd gen of WMMA
assert(wmmaLayout.getVersion() == 1);
result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type);
} else if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
auto parentLayout = sliceLayout.getParent();
Expand Down
2 changes: 0 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -814,8 +814,6 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0],
multiDimCTAInRepId[1]);
} else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout)) {
// TODO: support 2nd gen of WMMA
assert(wmmaLayout.getVersion() == 1);
emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0],
multiDimCTAInRepId[1]);
}
Expand Down
32 changes: 25 additions & 7 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,17 +565,35 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

// For wmma with 16x16 output, each of the 32 threads holds 8 elements.
//
// For the register (i.e., element) dimension, these 8 elements are along
// the matrix C's M dimension, with 1 consecutive elements spanning 1 row
// and then the next 1 row being a gap.
// The first version of the WMMA layout has the following specifics:
// for the register (i.e., element) dimension, these 8 elements are
// along the matrix C's M dimension, with each element occupying 1 row
// and the next row being a gap (i.e., a stride of 2 rows).
//
// For the lane (i.e., thread) dimension, these threads are along the
// matrix C's N dimension, with 16 consecutive threads covering a whole
// row and the next 16 threads start at the next row.
LinearLayout tileLayout(
{{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
//
// The second version of the WMMA layout is less tricky:
// for the register dimension, the 8 elements are consecutive along the
// matrix C's M dimension. The first 16 lanes take elements 0-7 along M,
// and the second 16 lanes take elements 8-15. There are 16 pairs of
// threads in each warp; each pair covers a whole column.
//
// Please also check explaining comments in TritonGPUAttrDefs.td at the
// AMDWmmaEncodingAttr section.
unsigned ver = getVersion();
assert(ver == 1 || ver == 2);
LinearLayout tileLayout =
ver == 1
? LinearLayout(
{{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}},
{outDimNames[order[0]], outDimNames[order[1]]})
: LinearLayout(
{{kRegister, {{0, 1}, {0, 2}, {0, 4}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}},
{outDimNames[order[0]], outDimNames[order[1]]});

if (hasBatchDim) {
assert(order[2] == 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ createNewConvertOps(ModuleOp &mod, OpBuilder &builder,
srcType.getElementType(), newMfmaEnc);
} else if (auto srcWmma = dyn_cast<triton::gpu::AMDWmmaEncodingAttr>(
srcType.getEncoding())) {
// TODO: support 2nd gen of WMMA
assert(srcWmma.getVersion() == 1);
auto newWmmaEnc = triton::gpu::AMDWmmaEncodingAttr::get(
mod.getContext(), srcWmma.getVersion(), {warpsPerCtaX, warpsPerCtaY},
srcWmma.getCTALayout());
Expand Down