diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h
index 37a3d3de90de..595f0e583c2b 100644
--- a/include/triton/Tools/LinearLayout.h
+++ b/include/triton/Tools/LinearLayout.h
@@ -610,6 +610,22 @@ class LinearLayout {
     return *this;
   }
 
+  // Compute a C such that A = B * C if it exists.
+  // In other words, C = B^{-1} * A.
+  // Note that such a C exists iff (every pair of input/output dim of) A is
+  // of the form
+  // [[B, 0],
+  //  [0, C]]
+  // as a matrix, whenever those dimensions are present in B.
+  //
+  // C will always have the same input/output dimensions as A.
+  // When there are dimensions of size 1 there is some ambiguity in the
+  // division, as in `operator*` we treat missing dimensions as dimensions
+  // of size 1 whenever it makes sense to do so. The rule that C has the
+  // same dimensions as A ensures that C is well-defined.
+  friend std::optional<LinearLayout> divideLeft(const LinearLayout &A,
+                                                const LinearLayout &B);
+
   // Returns true if this layout acts trivially (as the identity) on the given
   // dimensions. This means that it's the identity on those dimensions, and it
   // does not map other dimensions onto those or these onto other dimensions.
diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp
index 47255220c346..bd868c7df091 100644
--- a/lib/Tools/LinearLayout.cpp
+++ b/lib/Tools/LinearLayout.cpp
@@ -560,6 +560,80 @@ LinearLayout LinearLayout::concatOuts(const LinearLayout &other) const {
                       /*requiresSurjective=*/false);
 }
 
+std::optional<LinearLayout> divideLeft(const LinearLayout &A,
+                                       const LinearLayout &B) {
+  // Compute a C such that A = B * C if it exists.
+  // Note that such a C exists iff (every pair of input/output dim of) A is of
+  // the form
+  // [[B, 0],
+  //  [0, C]]
+  // as a matrix, whenever those dimensions are present in B.
+  for (StringAttr dim : B.getInDimNames()) {
+    if (!llvm::is_contained(A.getInDimNames(), dim))
+      return std::nullopt;
+  }
+  for (StringAttr dim : B.getOutDimNames()) {
+    if (!llvm::is_contained(A.getOutDimNames(), dim))
+      return std::nullopt;
+  }
+  // Compute C's output dim sizes; a dim missing from B counts as size 1.
+  llvm::MapVector<StringAttr, int32_t> cOutDimSizes;
+  for (StringAttr outDim : A.getOutDimNames()) {
+    int outA = A.getOutDimSizeLog2(outDim);
+    int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0;
+    int outC = outA - outB;
+    if (outC < 0)
+      return std::nullopt;
+    cOutDimSizes[outDim] = 1 << outC;
+  }
+
+  LinearLayout::BasesT cBases;
+  for (StringAttr inDim : A.getInDimNames()) {
+    int inA = A.getInDimSizeLog2(inDim);
+    int inB = B.hasInDim(inDim) ? B.getInDimSizeLog2(inDim) : 0;
+    int inC = inA - inB;
+    if (inC < 0)
+      return std::nullopt;
+
+    std::vector<std::vector<int32_t>> basesForDim;
+    // Check that A's first inB entries agree with B (0 for out dims B lacks).
+    for (int i = 0; i < inB; ++i) {
+      for (StringAttr outDim : A.getOutDimNames()) {
+        int expected = B.hasOutDim(outDim) ? B.getBasis(inDim, i, outDim) : 0;
+        int actual = A.getBasis(inDim, i, outDim);
+        if (actual != expected)
+          return std::nullopt;
+      }
+    }
+
+    // Extract the candidate C bases from the remaining (shifted) entries in A.
+    for (int i = inB; i < inA; ++i) {
+      std::vector<int32_t> candidateBasis;
+      for (StringAttr outDim : llvm::make_first_range(cOutDimSizes)) {
+        int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0;
+        int v = A.getBasis(inDim, i, outDim);
+
+        // The lower outB bits must be zero.
+        if ((v & ((1 << outB) - 1)) != 0)
+          return std::nullopt;
+        candidateBasis.push_back(v >> outB);
+      }
+      basesForDim.push_back(std::move(candidateBasis));
+    }
+    cBases[inDim] = std::move(basesForDim);
+  }
+
+  SmallVector<std::pair<StringAttr, int32_t>> COutDims;
+  for (auto [outDim, outC] : cOutDimSizes) {
+    COutDims.push_back({outDim, outC});
+  }
+  // If the layout A and B are surjective, then C should also be surjective.
+  LinearLayout C(std::move(cBases), COutDims,
+                 /*requireSurjective=*/A.isSurjective() && B.isSurjective());
+  assert(B * C == A);
+  return C;
+}
+
 LinearLayout operator*(LinearLayout inner, LinearLayout outer) {
   // Check that dims common to outer and inner have the same relative order.
   auto inDims = supremum(llvm::to_vector(inner.getInDimNames()),
diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp
index 4f89bc9c0ca0..9756eae79fa2 100644
--- a/unittest/Tools/LinearLayoutTest.cpp
+++ b/unittest/Tools/LinearLayoutTest.cpp
@@ -921,6 +921,113 @@ TEST(SupremumTest, ErrorOnInconsistentOrder) {
   ASSERT_DEATH({ supremum(x, y); }, "Supremum does not exist");
 }
 #endif
+
+TEST_F(LinearLayoutTest, DivideLeft_Basic) {
+  // Test division when A = B * C.
+  auto B = LinearLayout::identity1D(8, S("in"), S("out"));
+  auto C = LinearLayout::zeros1D(16, S("in"), S("out"));
+  auto isC = divideLeft(B * C, B);
+  ASSERT_TRUE(isC.has_value());
+  EXPECT_EQ(isC.value(), C);
+
+  auto isB = divideLeft(C * B, C);
+  ASSERT_TRUE(isB.has_value());
+  EXPECT_EQ(isB.value(), B);
+}
+
+TEST_F(LinearLayoutTest, DivideLeft_NonMatchingDims) {
+  // If B contains an extra input dimension not present in A, division should
+  // fail.
+  LinearLayout A = LinearLayout::identity1D(32, S("in"), S("out"));
+  LinearLayout B({{S("in"), {{1}, {2}, {4}, {8}}}, {S("extra"), {{0}}}},
+                 {S("out")});
+  auto candidateOpt = divideLeft(A, B);
+  EXPECT_FALSE(candidateOpt.has_value());
+}
+
+TEST_F(LinearLayoutTest, DivideLeft_Simple) {
+  EXPECT_EQ(divideLeft(LinearLayout::identity1D(8, S("in"), S("out")),
+                       LinearLayout::identity1D(4, S("in"), S("out"))),
+            LinearLayout::identity1D(2, S("in"), S("out")));
+
+  EXPECT_EQ(divideLeft(LinearLayout::identity1D(8, S("in"), S("out")),
+                       LinearLayout::identity1D(8, S("in"), S("out"))),
+            LinearLayout::identity1D(1, S("in"), S("out")));
+}
+
+TEST_F(LinearLayoutTest, DivideLeft_2D) {
+  LinearLayout l1(
+      {
+          {S("in1"), {{1, 1}, {2, 2}, {0, 8}, {0, 4}}},
+          {S("in2"), {{0, 2}, {0, 1}}},
+      },
+      {S("out1"), S("out2")});
+  LinearLayout l2(
+      {
+          {S("in1"), {{1, 1}, {2, 2}}},
+          {S("in2"), {{0, 2}, {0, 1}}},
+      },
+      {S("out1"), S("out2")});
+  LinearLayout l3({{S("in1"), {{0, 2}, {0, 1}}}, {S("in2"), {}}},
+                  {S("out1"), S("out2")});
+  ASSERT_EQ(l2 * l3, l1);
+  ASSERT_EQ(divideLeft(l1, l2).value(), l3);
+}
+
+TEST_F(LinearLayoutTest, DivideLeft_EliminateInDim) {
+  LinearLayout l1(
+      {
+          {S("in2"), {{0, 1}, {1, 0}}},
+          {S("in1"), {{2, 0}, {0, 2}}},
+      },
+      {S("out1"), S("out2")});
+  LinearLayout l2({{S("in2"), {{0, 1}, {1, 0}}}}, {S("out1"), S("out2")});
+  LinearLayout l3({{S("in2"), {}}, {S("in1"), {{1, 0}, {0, 1}}}},
+                  {S("out1"), S("out2")});
+  ASSERT_EQ(l2 * l3, l1);
+  EXPECT_EQ(divideLeft(l1, l2).value(), l3);
+
+  LinearLayout l4({{S("in1"), {{0, 1}, {0, 2}}}, {S("in2"), {}}},
+                  {S("out1"), S("out2")});
+  LinearLayout l5({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")});
+  LinearLayout l6({{S("in1"), {}}, {S("in2"), {}}}, {S("out1"), S("out2")});
+  ASSERT_EQ(l5 * l6, l4);
+  EXPECT_EQ(divideLeft(l4, l5).value(), l6);
+
+  LinearLayout l7({{S("in1"), {}}, {S("in2"), {{0, 1}}}, {S("in3"), {}}},
+                  {S("out1"), S("out2")});
+  LinearLayout l8({{S("in2"), {{0, 1}}}}, {S("out1"), S("out2")});
+  LinearLayout l9({{S("in1"), {}}, {S("in2"), {}}, {S("in3"), {}}},
+                  {S("out1"), S("out2")});
+  ASSERT_EQ(l8 * l9, l7);
+  EXPECT_EQ(divideLeft(l7, l8).value(), l9);
+}
+
+TEST_F(LinearLayoutTest, DivideLeft_EliminateOutDim) {
+  LinearLayout l1(
+      {
+          {S("in2"), {{1, 0}, {1, 0}}},
+          {S("in1"), {{2, 0}, {0, 1}}},
+      },
+      {S("out1"), S("out2")});
+  LinearLayout l2({{S("in2"), {{1, 0}, {1, 0}}}}, {S("out1"), S("out2")});
+  LinearLayout l3({{S("in2"), {}}, {S("in1"), {{1, 0}, {0, 1}}}},
+                  {S("out1"), S("out2")});
+  ASSERT_EQ(l2 * l3, l1);
+  EXPECT_EQ(divideLeft(l1, l2).value(), l3);
+
+  LinearLayout l4(
+      {
+          {S("in1"), {{0, 1}, {0, 2}}},
+      },
+      {S("out1"), S("out2")});
+  using BasesArray =
+      ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>>;
+  LinearLayout l5(BasesArray{}, {S("out1")});
+  LinearLayout l6({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")});
+  ASSERT_EQ(l5 * l6, l4);
+  EXPECT_EQ(divideLeft(l4, l5).value(), l6);
+}
 } // anonymous namespace
 } // namespace mlir::triton
 