Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
assert(rank == 2 || rank == 3 && "Invalid dotLayout");

// Do not split CTA in K dimension
getOpIdx() == 0 ? res[rank - 1] = 1 : res[rank - 2] = 1;
auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
res[kDim] = 1;
return res;
}
SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
Expand Down
328 changes: 143 additions & 185 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,78 +280,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
return ret;
}

LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
NvidiaMmaEncodingAttr mma) {
int rank = shape.size();

assert(mma.isAmpere());
assert(rank == 2 || rank == 3);
assert(mma.getInstrShape().size() == rank);
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);

auto orderedDimNames = permuteDimNames(dimNames, mma.getRepOrder());
assert(mma.getRepOrder() == getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
ArrayRef(orderedDimNames).take_front(2));
assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant with the order of the out dims.
ctaLayout *=
identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}

LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
NvidiaMmaEncodingAttr mma) {
int rank = shape.size();
assert(mma.isHopper());
assert(rank == 2);

// wgmma operates on groups of 4 warps.
assert(product(mma.getWarpsPerCTA()) % 4 == 0);

// Check that it's a known MMA layout.
assert(mma.getInstrShape().size() == 3);
int m = mma.getInstrShape()[0];
int n = mma.getInstrShape()[1];
int k = mma.getInstrShape()[2];
assert(m == 16);
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

MLIRContext *ctx = mma.getContext();
LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
{S("dim1"), S("dim0")});

// Expand the `register` dimension so the size of dim1 matches `n`.
ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
S("register"), S("dim1"));

// The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
// Since the warpOrder needs to be M-major, we need to transpose the out
// dimensions AND transpose the order
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant. The order is already given by the order of the
// out dims, and if it has an order, it shouldn't change the
// order of the out dims.
assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
{S("dim0"), S("dim1")})
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}

LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef<int64_t> shape,
SharedEncodingAttr shared) {
assert(!shared.getHasLeadingOffset());
Expand Down Expand Up @@ -779,13 +707,153 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}

LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
unsigned kWidth, ArrayRef<unsigned> order,
ArrayRef<unsigned> repOrder) {
// Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
int rank = repOrder.size();
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout ctaLayout =
identityND(S("register"), trivialShape, repOrder, dimNames);

assert(rank >= 2);
auto inner = order[0];
auto outer = order[1];

assert(tileShape.size() == rank);
int m = tileShape[outer];
int n = tileShape[inner];

// The relative order of registers and lanes is given by:
// - Inner dim: kWidth registers
// - Inner dim: 4 lanes
// - Outer dim: 8 lanes
// - Outer dim: repeat m / 8 times
// - Inner dim: repeat n / (kWidth * 4) times
assert(m % 8 == 0);
assert(n % (kWidth * 4) == 0);
// There is at least one subtile on the inner-most dimension
// FIXME. We should implement operator* in terms of operator*=
// and chain *= instead of using *
auto outDimNames = llvm::to_vector(ctaLayout.getOutDimNames());
ctaLayout = ctaLayout *
LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) *
LinearLayout::identity1D(4, S("lane"), dimNames[inner]) *
LinearLayout::identity1D(8, S("lane"), dimNames[outer]) *
LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) *
LinearLayout::identity1D(n / (kWidth * 4), S("register"),
dimNames[inner]);
return ctaLayout;
}

std::optional<LinearLayout>
NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
auto ctx = getContext();
int rank = shape.size();

SmallVector<unsigned> tileShape;
if (isAmpere()) {
return ampereMmaToLinearLayout(shape, *this);
// Ampere.getInstrShape() returns the tile shape
tileShape = SmallVector<unsigned>(getInstrShape());
} else {
assert(isHopper());
auto instrShapeMNK = getInstrShape();
tileShape = SmallVector<unsigned>({instrShapeMNK[0], instrShapeMNK[1]});
}
if (isHopper()) {
return hopperMmaToLinearLayout(shape, *this);
// nvidiamma layout always assumes kWidth = 2
constexpr auto kWidth = 2;
auto ctaLayout =
nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(*this), getRepOrder());

// The triton orders are defined on [dim0, dim1, ...], so we need to pass
// those dims Then, for some reason, operator* requires the orders to match
// so we need to reorder the outs to match
// FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
// The order in triton assumes the standardDims, so it should
// use those.
ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
standardOutDimNames(ctx, rank))
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}

LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
ArrayRef<unsigned> mmaWarpOrder, bool isA) {
// Let warpsPerCTAMma = {2, 2}, then
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
// assume warpOrder = {1, 0}
// Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
// the C is owned as per the following layout:
// C: 0 | 1
// - | -
// 2 | 3
// In order to be able to compute C, we need the following warp tiling of
// A and B:
// A: 0 1 | 0 1 B: 0 2 | 1 3
// - - | - - - - | - -
// 2 3 | 2 3 0 2 | 1 3
// In other words, we need to broadcast along K
auto rank = mmaWarpOrder.size();
auto inner = isA ? rank - 1 : rank - 2;
auto outer = isA ? rank - 2 : rank - 1;
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout warpLayout =
identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames);

// We have to broadcast along the inner dimension
// For A, when moving along M we go from 0 to 2.
// For B, when moving along N we go from 0 to 1.
// As such, choosing the order of A {1, 0}, gives us the correct broadcasting
// Same happens if the mmaWarpOrder is {0, 1}, like in Hopper
for (auto d : mmaWarpOrder) {
if (d == inner) {
warpLayout *=
LinearLayout::zeros1D(mmaWarpShape[d], S("warp"), dimNames[d]);
} else {
warpLayout *=
LinearLayout::identity1D(mmaWarpShape[d], S("warp"), dimNames[d]);
}
}
return warpLayout;
}

LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;
MLIRContext *ctx = mma.getContext();

SmallVector<unsigned> tileShape(rank, 1);
if (isA) {
tileShape[rank - 2] = 16;
tileShape[rank - 1] = kWidth * 8;
} else {
// Hopper takes the rhs via shared memory
assert(mma.isAmpere());
tileShape[rank - 2] = kWidth * 8;
tileShape[rank - 1] = 8;
}
auto ctaLayout =
nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
ctaLayout *=
warpsNvidiaDot(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), isA)
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
}

std::optional<LinearLayout>
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
auto parent = getParent();
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
return mfmaDotToLinearLayout(*this, shape);
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
return nvidiaDotToLinearLayout(shape, *this);
}
return std::nullopt;
}
Expand Down Expand Up @@ -860,116 +928,6 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return ret;
}

LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
// Note that, even though MMAv2 looks similar to this layout, they are just
// the same at a register and lane level. The warps treatment is different!
int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;

assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
assert(mma.isAmpere());

MLIRContext *ctx = mma.getContext();

// The A and B operands are tiled in a kMajor fashion
auto kMajorOrder = dot.getRepOrder();
assert(kMajorOrder ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder);
// This agrees with the order of the elements, which means that we can share
// the code below for both A and B without having to perform any swaps
assert(getOrder(dot) == kMajorOrder);

std::vector<std::vector<int32_t>> registers;
std::vector<std::vector<int32_t>> lanes;
int32_t i = 1;
// kWidth contiguous elements
while (i < kWidth) {
registers.push_back({i, 0});
i *= 2;
}
// 4 threads per chunk
for (int j = 0; j < 2; j++) {
lanes.push_back({i, 0});
i *= 2;
}
// 8 threads going down
lanes.push_back({0, 1});
lanes.push_back({0, 2});
lanes.push_back({0, 4});
// 2 tiles in column-major order
// Just one if it's the B operand
if (isA) {
registers.push_back({0, 8});
}
registers.push_back({i, 0});

LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
ArrayRef(kMajorDims).take_front(2));

// Let warpsPerCTAMma = {2, 2}, then
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
// assume warpOrder = {0, 1}
// Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
// the C is owned as per the following layout:
// C: 0 | 1
// - | -
// 2 | 3
// In order to be able to compute C, we need the following warp tiling of
// A and B:
// A: 0 1 | 0 1 B: 0 2 | 1 3
// - - | - - - - | - -
// 2 3 | 2 3 0 2 | 1 3
// In particular, for A and B we need to broadcast along K

assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
auto warpsPerCTAMma = mma.getWarpsPerCTA();
std::vector<std::vector<int32_t>> warps;
if (isA) {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, 0});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, i});
}
} else {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, i});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, 0});
}
}
if (rank == 3) {
Comment thread
lezcano marked this conversation as resolved.
for (auto &w : warps) {
w.push_back(0);
}
}

ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);

return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
}

std::optional<LinearLayout>
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
auto parent = getParent();
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
return mfmaDotToLinearLayout(*this, shape);
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.isAmpere()) {
return ampereDotToLinearLayout(shape, *this);
}
}
return std::nullopt;
}

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth /*= std::nullopt*/) {
Expand Down
Loading