Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,22 @@ class LinearLayout {
return *this;
}

// Compute a C such that A = B * C if it exists.
// In other words, C = B^{-1} * A.
// Note that such a C exists iff (every pair of input/output dim of) A is
// of the form
// [[B, 0],
// [0, C]]
// as a matrix, whenever those dimensions are present in B.
//
// C will always have the same input/output dimensions as A.
// When there are dimensions of size 1 there is some ambiguity in the
// division, as in `operator*` we treat missing dimensions as dimensions
// of size 1 whenever it makes sense to do so. The rule that C has the
// same dimensions as A ensures that C is well-defined.
friend std::optional<LinearLayout> divideLeft(const LinearLayout &A,
const LinearLayout &B);

// Returns true if this layout acts trivially (as the identity) on the given
// dimensions. This means that it's the identity on those dimensions, and it
// does not map other dimensions onto those or these onto other dimensions.
Expand Down
74 changes: 74 additions & 0 deletions lib/Tools/LinearLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,80 @@ LinearLayout LinearLayout::concatOuts(const LinearLayout &other) const {
/*requiresSurjective=*/false);
}

std::optional<LinearLayout> divideLeft(const LinearLayout &A,
const LinearLayout &B) {
// Compute a C such that A = B * C if it exists.
// Note that such a C exists iff (every pair of input/output dim of) A is of
// the form
// [[B, 0],
// [0, C]]
// as a matrix, whenever those dimensions are present in B.
for (StringAttr dim : B.getInDimNames()) {
if (!llvm::is_contained(A.getInDimNames(), dim))
return std::nullopt;
}
for (StringAttr dim : B.getOutDimNames()) {
if (!llvm::is_contained(A.getOutDimNames(), dim))
return std::nullopt;
}
// Compute candidate C's log-sizes for output dimensions.
llvm::MapVector<StringAttr, int32_t> cOutDimSizes;
for (StringAttr outDim : A.getOutDimNames()) {
int outA = A.getOutDimSizeLog2(outDim);
int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0;
int outC = outA - outB;
if (outC < 0)
return std::nullopt;
cOutDimSizes[outDim] = 1 << outC;
}

LinearLayout::BasesT cBases;
for (StringAttr inDim : A.getInDimNames()) {
int inA = A.getInDimSizeLog2(inDim);
int inB = B.hasInDim(inDim) ? B.getInDimSizeLog2(inDim) : 0;
int inC = inA - inB;
if (inC < 0)
return std::nullopt;

std::vector<std::vector<int32_t>> basesForDim;
// Check that A’s first inB entries agree with B.
for (int i = 0; i < inB; ++i) {
for (StringAttr outDim : A.getOutDimNames()) {
int expected = B.hasOutDim(outDim) ? B.getBasis(inDim, i, outDim) : 0;
int actual = A.getBasis(inDim, i, outDim);
if (actual != expected)
return std::nullopt;
}
}

// Extract the candidate C bases from the remaining (shifted) entries in A.
for (int i = inB; i < inA; ++i) {
std::vector<int32_t> candidateBasis;
for (StringAttr outDim : llvm::make_first_range(cOutDimSizes)) {
int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0;
int v = A.getBasis(inDim, i, outDim);

// The lower outB bits must be zero.
if ((v & ((1 << outB) - 1)) != 0)
return std::nullopt;
candidateBasis.push_back(v >> outB);
}
basesForDim.push_back(std::move(candidateBasis));
}
cBases[inDim] = basesForDim;
}

SmallVector<std::pair<StringAttr, int32_t>> COutDims;
for (auto [outDim, outC] : cOutDimSizes) {
COutDims.push_back({outDim, outC});
}
// If the layout A and B are surjective, then C should also be surjective.
LinearLayout C(std::move(cBases), COutDims,
/*requireSurjective=*/A.isSurjective() && B.isSurjective());
assert(B * C == A);
return C;
}

LinearLayout operator*(LinearLayout inner, LinearLayout outer) {
// Check that dims common to outer and inner have the same relative order.
auto inDims = supremum(llvm::to_vector(inner.getInDimNames()),
Expand Down
107 changes: 107 additions & 0 deletions unittest/Tools/LinearLayoutTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,113 @@ TEST(SupremumTest, ErrorOnInconsistentOrder) {
ASSERT_DEATH({ supremum(x, y); }, "Supremum does not exist");
}
#endif

TEST_F(LinearLayoutTest, DivideLeft_Basic) {
// Test division when A = B * C.
auto B = LinearLayout::identity1D(8, S("in"), S("out"));
auto C = LinearLayout::zeros1D(16, S("in"), S("out"));
auto isC = divideLeft(B * C, B);
EXPECT_TRUE(isC.has_value());
EXPECT_EQ(isC.value(), C);

auto isB = divideLeft(C * B, C);
EXPECT_TRUE(isB.has_value());
EXPECT_EQ(isB.value(), B);
}

TEST_F(LinearLayoutTest, DivideLeft_NonMatchingDims) {
// If B contains an extra input dimension not present in A, division should
// fail.
LinearLayout A = LinearLayout::identity1D(32, S("in"), S("out"));
LinearLayout B({{S("in"), {{1}, {2}, {4}, {8}}}, {S("extra"), {{0}}}},
{S("out")});
auto candidateOpt = divideLeft(A, B);
EXPECT_FALSE(candidateOpt.has_value());
}

TEST_F(LinearLayoutTest, DivideLeft_Simple) {
EXPECT_EQ(divideLeft(LinearLayout::identity1D(8, S("in"), S("out")),
LinearLayout::identity1D(4, S("in"), S("out"))),
LinearLayout::identity1D(2, S("in"), S("out")));

EXPECT_EQ(divideLeft(LinearLayout::identity1D(8, S("in"), S("out")),
LinearLayout::identity1D(8, S("in"), S("out"))),
LinearLayout::identity1D(1, S("in"), S("out")));
}

TEST_F(LinearLayoutTest, DivideLeft_2D) {
LinearLayout l1(
{
{S("in1"), {{1, 1}, {2, 2}, {0, 8}, {0, 4}}},
{S("in2"), {{0, 2}, {0, 1}}},
},
{S("out1"), S("out2")});
LinearLayout l2(
{
{S("in1"), {{1, 1}, {2, 2}}},
{S("in2"), {{0, 2}, {0, 1}}},
},
{S("out1"), S("out2")});
LinearLayout l3({{S("in1"), {{0, 2}, {0, 1}}}, {S("in2"), {}}},
{S("out1"), S("out2")});
ASSERT_EQ(l2 * l3, l1);
ASSERT_EQ(divideLeft(l1, l2).value(), l3);
}

TEST_F(LinearLayoutTest, DivideLeft_EliminateInDim) {
LinearLayout l1(
{
{S("in2"), {{0, 1}, {1, 0}}},
{S("in1"), {{2, 0}, {0, 2}}},
},
{S("out1"), S("out2")});
LinearLayout l2({{S("in2"), {{0, 1}, {1, 0}}}}, {S("out1"), S("out2")});
LinearLayout l3({{S("in2"), {}}, {S("in1"), {{1, 0}, {0, 1}}}},
{S("out1"), S("out2")});
ASSERT_EQ(l2 * l3, l1);
EXPECT_EQ(divideLeft(l1, l2).value(), l3);

LinearLayout l4({{S("in1"), {{0, 1}, {0, 2}}}, {S("in2"), {}}},
{S("out1"), S("out2")});
LinearLayout l5({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")});
LinearLayout l6({{S("in1"), {}}, {S("in2"), {}}}, {S("out1"), S("out2")});
ASSERT_EQ(l5 * l6, l4);
EXPECT_EQ(divideLeft(l4, l5).value(), l6);

LinearLayout l7({{S("in1"), {}}, {S("in2"), {{0, 1}}}, {S("in3"), {}}},
{S("out1"), S("out2")});
LinearLayout l8({{S("in2"), {{0, 1}}}}, {S("out1"), S("out2")});
LinearLayout l9({{S("in1"), {}}, {S("in2"), {}}, {S("in3"), {}}},
{S("out1"), S("out2")});
ASSERT_EQ(l8 * l9, l7);
EXPECT_EQ(divideLeft(l7, l8).value(), l9);
}

TEST_F(LinearLayoutTest, DivideLeft_EliminateOutDim) {
LinearLayout l1(
{
{S("in2"), {{1, 0}, {1, 0}}},
{S("in1"), {{2, 0}, {0, 1}}},
},
{S("out1"), S("out2")});
LinearLayout l2({{S("in2"), {{1, 0}, {1, 0}}}}, {S("out1"), S("out2")});
LinearLayout l3({{S("in2"), {}}, {S("in1"), {{1, 0}, {0, 1}}}},
{S("out1"), S("out2")});
ASSERT_EQ(l2 * l3, l1);
EXPECT_EQ(divideLeft(l1, l2).value(), l3);

LinearLayout l4(
{
{S("in1"), {{0, 1}, {0, 2}}},
},
{S("out1"), S("out2")});
using BasesArray =
ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>>;
LinearLayout l5(BasesArray{}, {S("out1")});
LinearLayout l6({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")});
ASSERT_EQ(l5 * l6, l4);
EXPECT_EQ(divideLeft(l4, l5).value(), l6);
}
} // anonymous namespace
} // namespace mlir::triton

Expand Down
Loading