Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ unsigned getNumWarpsPerCTA(Attribute layout);

unsigned getNumCTAs(Attribute layout);

// Return the order that represents that the batch is in row-major or
// column-major order for a batch of matrices of shape [*, m, n] with
// len(shape) == rank.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

// Return the order that represents that the dot operand is in kMajor
// (contiguous in the inner dimension) or it's contiguous on the outer
// dimension.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

// Return true if a view between the two types cannot be implemented as a no-op.
Expand Down
74 changes: 34 additions & 40 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,19 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
return resOrder;
}

SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused about the name "matrix" when rank > 2

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We consider it as a batched matrix of shape [*, m, n], where * are zero or more dimensions (see the comment below).

// Return the order that represents that the batch is in row-major or
// column-major order for a batch of matrices of shape [*, m, n] with
// len(shape) == rank.
assert(rank >= 2);
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
if (!rowMajor) {
std::swap(order[0], order[1]);
}
return order;
}

SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor) {
// kMajor: if true, the matrix is fastest-running on k,
Expand All @@ -244,15 +257,8 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
// batch (if rank == 3) is always the slowest running dimension
assert(rank == 2 || rank == 3);
assert(opIdx == 0 || opIdx == 1);
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
// If opIdx is 1 and kMajor is true, the order is [0, 1]
// (resp. [1, 2, 0] if rank == 3)
// Same if opIdx is 0 and kMajor is false
if (bool(opIdx) == kMajor) {
std::swap(order[0], order[1]);
}
return order;
auto rowMajor = bool(opIdx) != kMajor;
return getMatrixOrder(rank, rowMajor);
}

SmallVector<unsigned> getWarpOrder(Attribute layout) {
Expand All @@ -262,20 +268,21 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
}
}
auto order = getOrder(layout);
// FIXME: This mmaLayout if should just return
// getOrderForDotOperand(0, order.size(), kMajor=false)
// as mma has the same order as DotOperand(opIdx=0)
// FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
// M-major This is awkward. Since we can choose any warpOrder in Ampere, we
// should probably choose M-major and change `LinearLayoutConversion.cpp` and
// `MMAv2.cpp` to match.
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isHopper()) {
// Hopper MMA instructions force a warp order of [0, 1]. See docs:
// Hopper MMA instructions force warps to be column-major
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
auto it = std::find(order.begin(), order.end(), 0);
order.erase(it);
order.insert(order.begin(), 0);
return getMatrixOrder(order.size(), /*rowMajor*/ false);
}
} else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
/*kMajor*/ false);
// It's quite weird to talk about warp order when that the warps
// are broadcasted along the K dimension
llvm::report_fatal_error(
"DotOperandEncoding::getWarpOrder not implemented");
}
return order;
}
Expand All @@ -285,11 +292,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
return llvm::to_vector(blockedLayout.getOrder());
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the elements in the MMAv2/V3 lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
return order;
return getMatrixOrder(rank, /*rowMajor*/ true);
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto rank = dotLayout.getWarpsPerCTA().size();
Expand Down Expand Up @@ -421,7 +428,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout))
warpsPerCTA = wmmaLayout.getWarpsPerCTA();
else if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))
return getNumWarpsPerCTA(dotLayout.getParent());
warpsPerCTA = dotLayout.getWarpsPerCTA();
else if (auto sharedLayout = dyn_cast<SharedEncodingAttr>(layout))
llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
else
Expand Down Expand Up @@ -2136,25 +2143,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTATile = getShapePerCTATile(shape);
auto rank = parentShapePerCTATile.size();
auto shapePerCTATile = getShapePerCTATile(shape);
auto rank = shapePerCTATile.size();
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
// 4 threads * 2 subtiles
unsigned kWidthTile = kWidth * 2 * 4;
if (opIdx == 0) {
if (rank == 2)
return {parentShapePerCTATile[rank - 2], kWidthTile};
else
return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
kWidthTile};
} else if (opIdx == 1) {
if (rank == 2)
return {kWidthTile, parentShapePerCTATile[rank - 1]};
else
return {parentShapePerCTATile[0], kWidthTile,
parentShapePerCTATile[rank - 1]};
} else {
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
shapePerCTATile[kDim] = kWidth * 2 * 4;
return shapePerCTATile;
}
SmallVector<unsigned>
NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
Expand Down
104 changes: 78 additions & 26 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
return ret;
}

// TODO Have order be a mandatory argument of standardOutDimNames.
SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
const SmallVector<unsigned> &order) {
assert(names.size() == order.size());
SmallVector<StringAttr> ret;
for (unsigned i : order) {
ret.push_back(names[i]);
}
return ret;
}

void assertIsRegisterLayout(const LinearLayout &layout) {
assert(layout.getNumInDims() > 0);
MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
Expand Down Expand Up @@ -281,15 +292,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
// By using `reverse(dimNames)` below, we set the order to be row-major
assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

ctaLayout *= identityND(
S("warp"), mma.getWarpsPerCTA(),
llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))), dimNames);
ArrayRef(orderedDimNames).take_front(2));
assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant with the order of the out dims.
ctaLayout *=
identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}
Expand Down Expand Up @@ -322,10 +337,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
S("register"), S("dim1"));

// Expand the `warp` dimension according to warpsPerCTA.
//
// It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but
// this really does seem to be correct.
// The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
// Since the warpOrder needs to be M-major, we need to transpose the out
// dimensions AND transpose the order
// FIXME(Lezcano). identityND should not have an `order` param as it's

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree

// redundant. The order is already given by the order of the
// out dims, and if it has an order, it shouldn't change the
// order of the out dims.
assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
{S("dim0"), S("dim1")})
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
Expand Down Expand Up @@ -843,18 +862,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
// TODO,BE. Implement ampereMMA in terms of this one
// Note that, even though MMAv2 looks similar to this layout, they are just
// the same at a register and lane level. The warps treatment is different!
int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;

assert(mma.isAmpere());
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
assert(mma.isAmpere());

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
// A and B have kMajor order
assert(getOrder(dot) ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
Expand All @@ -881,24 +906,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
}
registers.push_back({i, 0});

if (!isA) {
for (auto &r : registers) {
std::swap(r[0], r[1]);
LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
ArrayRef(kMajorDims).take_front(2));

// Let warpsPerCTAMma = {2, 2}, then
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
// assume warpOrder = {0, 1}
// Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
// the C is owned as per the following layout:
// C: 0 | 1
// - | -
// 2 | 3
// In order to be able to compute C, we need the following warp tiling of
// A and B:
// A: 0 1 | 0 1 B: 0 2 | 1 3
// - - | - - - - | - -
// 2 3 | 2 3 0 2 | 1 3
// In particular, for A and B we need to broadcast along K

assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
auto warpsPerCTAMma = mma.getWarpsPerCTA();
std::vector<std::vector<int32_t>> warps;
if (isA) {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, 0});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, i});
}
} else {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, i});
}
for (auto &l : lanes) {
std::swap(l[0], l[1]);
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, 0});
}
}
if (rank == 3) {
for (auto &w : warps) {
w.push_back(0);
}
}

LinearLayout ctaLayout(
{{S("register"), registers}, {S("lane"), lanes}},
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

auto order = dot.getCTAOrder();
assert(order[0] == rank - 1 && order[1] == rank - 2);
ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
}

std::optional<LinearLayout>
Expand All @@ -907,7 +959,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
return mfmaDotToLinearLayout(*this, shape);
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) {
if (mma.isAmpere()) {
return ampereDotToLinearLayout(shape, *this);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,15 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
}

if (dot.getOpIdx() == 1) {
// there are kWidth * 2 elems packed as bf16x2
int elemsInTile = dot.getKWidth();
// n0 and n1 are unrolled in the legacy path
// Unrolling n1 makes some sense, but unrolling n0 makes absolutely no
// sense IMO
// n0 is unrolled in the legacy path, which makes no sense
n0 *= 2;
n1 *= 2;
for (auto b = 0; b < batch; ++b)
for (auto j = 0; j < n1 / elemsInTile; ++j)
for (auto i = 0; i < n0; ++i)
for (auto k = 0; k < elemsInTile; ++k) {
vals[{b, i, elemsInTile * j + k}] = elems[offset++];
}
for (auto i = 0; i < n0; ++i)
for (auto j = 0; j < n1; ++j) {
vals[{b, i, 2 * j}] = elems[offset++];
vals[{b, i, 2 * j + 1}] = elems[offset++];
}
return vals;
}
}
Expand Down
41 changes: 37 additions & 4 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,14 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
{2, 0},
{4, 0},
{32, 0},
{64, 0},
{0, 8},
{0, 16},
{0, 32},
{64, 0}}},
{0, 32}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{
S("warp"),
{},
{{0, 0}, {0, 0}},
},
{S("block"), {}},
},
Expand All @@ -582,13 +582,46 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{
S("warp"),
{},
{{0, 0}, {0, 0}},
},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) {
EXPECT_EQ(
toLinearLayout({32, 64}, dotMMAv2(0, 8, {2, 2})),
LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{0, 0}, {16, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
EXPECT_EQ(
toLinearLayout({64, 16}, dotMMAv2(1, 8, {2, 2})),
LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{S("warp"), {{0, 8}, {0, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(0, 8, {2, 2})),
LinearLayout(
{{S("register"),
{{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}},
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{0, 0}, {16, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
EXPECT_EQ(
toLinearLayout({128, 32}, dotMMAv2(1, 8, {2, 2})),
LinearLayout(
{{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{S("warp"), {{0, 8}, {0, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
/*isTransposed=*/false);
Expand Down