Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

Expand Down
7 changes: 7 additions & 0 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,13 @@ class LinearLayout {
// (i.e. every input bit affects the output).
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;

// Increase an input dimension without affecting the output dimension. The
// added free variables are mapped to 0, ensuring that the new input
// dimensions correspond directly to the existing output space. The function
// errors out if `newInDimSize` is less than the current size or the new size
// is not a power of 2.
LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;

std::string toString() const;

friend bool operator==(LinearLayout lhs, LinearLayout rhs);
Expand Down
42 changes: 40 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (blockedLayout == nullptr || dotOperandLayout == nullptr)
Expand Down Expand Up @@ -655,8 +655,46 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
if (!(srcLayout.has_value() && dstLayout.has_value()))
return std::nullopt;
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
auto numSrcRegs = srcLayout->getInDimSize(kRegister);
auto numDstRegs = dstLayout->getInDimSize(kRegister);
// The `invertAndCompose` function will generate a layout that is injective
// by assigning new output dimensions to free variables. For instance,
// consider a scenario where `srcLayout` has a free variable in the lane
// dimension, while `dstLayout` has two free variables in the lane
// dimension and also a larger number of registers.
// The injective form of `srcLayout` will add only a single additional row
// to the transformation matrix, whereas the injective form of `dstLayout`
// will add two additional rows. This discrepancy causes misleading results
// because the matrices end up with a different number of rows.
//
// Take `dstLayout ⋅ srcLayout^-1` as an example:
//
// - `injective(dstLayout)`: [n, m] → [n + 2, m]
// - `injective(srcLayout)`: [n, m] → [n + 1, m]
// - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
// - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
// 1] → [n + 2, n + 1]
//
// Here, the `(n + 1)`-th row added by `dstLayout` represents the free
// variable in registers, and the `(n + 2)`-th row represents the free
// variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
// represents the free variable in lanes. As a result, the `(n + 1)`-th row
// in two layouts do not correspond to the same free variable.
//
// To address this issue, we pad the free variables in `srcLayout` and
// `dstLayout` to ensure they have the same number of registers. This
// guarantees that the resulting matrices have the same number of rows,
// ensuring consistency in the composition process.
auto numRegs = std::max(numSrcRegs, numDstRegs);
auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
// comp describes the layout function to create dst from src.
LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
LinearLayout comp =
dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
// We try to quotient by the largest subspace first
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
for (auto dim : dims) {
Expand Down
53 changes: 32 additions & 21 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,60 +288,71 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return rewriter.notifyMatchFailure(
op, "NYI. srcTy and/or dstTy don't implement LLs yet");
}
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());

StringAttr kBlock = str_attr("block");
StringAttr kWarp = str_attr("warp");
StringAttr kLane = str_attr("lane");
StringAttr kRegister = str_attr("register");

assert(to_vector(conversion->getInDimNames()) ==
to_vector(conversion->getOutDimNames()));
auto dims = conversion->getInDimNames();
if (llvm::is_contained(dims, str_attr("block"))) {
if (llvm::is_contained(dims, kBlock)) {
// Case 1: Transfer between values in different CTAs.
// This requires moving values through distributed shared memory.
return rewriter.notifyMatchFailure(
op, "NYI: Transfer between different CTAs");
} else if (llvm::is_contained(dims, str_attr("warp"))) {
} else if (llvm::is_contained(dims, kWarp)) {
// Case 2: Transfer between values in the same CTA, in which case we move
// values through shared memory.
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
} else if (llvm::is_contained(dims, str_attr("lane"))) {
} else if (llvm::is_contained(dims, kLane)) {
// Case 3. Transfer between values in the same warp, in which case we try
// to move values using warp shuffles, though if the pattern is
// complicated enough we may fall back to using shared memory
// TODO(Keren): implement warp shuffle instead of using the general
// approach that uses shared memory
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
} else if (llvm::is_contained(dims, str_attr("register"))) {
} else if (llvm::is_contained(dims, kRegister) ||
dstLayout.getInDimSize(kRegister) !=
srcLayout.getInDimSize(kRegister)) {
// Case 4. Transfer between values in the same thread, in which case we
// simply reorder the elements of adaptor.getSrc().
return transferWithinThread(op, *conversion, adaptor, rewriter);
return transferWithinThread(
op, dstLayout.getFreeVariableMasks()[kRegister],
dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
} else {
// The two layouts are equivalent. We should probably remove these in
// RemoveLayoutConversion.
// Cast 5. The two layouts are equivalent. We should probably remove
// these in RemoveLayoutConversion.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
}

LogicalResult
transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
OpAdaptor adaptor,
transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
const LinearLayout &conversion, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
StringAttr kRegister = str_attr("register");
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> outVals;
outVals.resize(conversion.getInDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
SmallVector<Value> outVals(numRegs);
for (int i = 0; i < outVals.size(); i++) {
// Remove free masks from the register index
// For example, if idx = 0b00111, and masks = 0b00100, then we get
// 0b00011. It means that register 7 (0b111) has the same value as
// register 3 (0b011).
auto idx = i & (~regMasks);
auto srcIdx = conversion.hasInDim(kRegister)
? conversion.apply({{kRegister, idx}}).begin()->second
: idx;
outVals[i] = inVals[srcIdx];
}
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
Expand Down
15 changes: 15 additions & 0 deletions lib/Tools/LinearLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,21 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const {
return true;
}

LinearLayout LinearLayout::resize(StringAttr inDim,
int32_t newInDimSize) const {
BasesT bases = getBases();
assert(bases.contains(inDim) && "inDim not in layout");
assert(llvm::isPowerOf2_32(newInDimSize) &&
"newInDimSize must be a power of 2");
assert(newInDimSize >= getInDimSize(inDim) &&
"newInDimSize must be >= old size");
auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim);
for (int i = 0; i < numFreeVariables; i++) {
bases[inDim].push_back(std::vector<int32_t>(getNumOutDims(), 0));
}
return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames()));
}

std::string LinearLayout::toString() const {
// Start with a newline because we print out a bulleted list; it doesn't
// make sense for the first line of this list to be on the same line as
Expand Down
74 changes: 74 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

// -----

#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_layout_mmav2_dot_reg
tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
tt.return
}
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout_mmav3_mmav3_0
tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
tt.return
}
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout_mmav3_mmav3_1
tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
tt.return
}
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout_mmav3_mmav3_2
tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
tt.return
}
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout_mmav3_mmav3_3
tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
Expand Down
33 changes: 33 additions & 0 deletions unittest/Tools/LinearLayoutTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,39 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) {
ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value());
}

TEST_F(LinearLayoutTest, Resize) {
auto init = LinearLayout(
{
{S("in0"), {{0, 1}, {0, 2}}},
{S("in1"), {{1, 0}, {2, 0}}},
{S("in2"), {}},
},
{S("dim0"), S("dim1")});
EXPECT_EQ(init.resize(S("in0"), 8),
LinearLayout(
{
{S("in0"), {{0, 1}, {0, 2}, {0, 0}}},
{S("in1"), {{1, 0}, {2, 0}}},
{S("in2"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout(
{
{S("in0"), {{0, 1}, {0, 2}}},
{S("in1"), {{1, 0}, {2, 0}}},
{S("in2"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(init.resize(S("in1"), 8),
LinearLayout(
{
{S("in0"), {{0, 1}, {0, 2}}},
{S("in1"), {{1, 0}, {2, 0}, {0, 0}}},
{S("in2"), {}},
},
{S("dim0"), S("dim1")}));
}

} // anonymous namespace
} // namespace mlir::triton

Expand Down