From 0d283f8a39fdd890b9bb3a5fc882976fce0276f6 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 23 Apr 2025 08:23:40 +0100 Subject: [PATCH 1/2] [LAYOUTS] Implement divideLeft Finally got around to implementing `divideLeft`. I rescued and adapted the `divideRight` tests from https://github.com/triton-lang/triton/pull/5170 and added a few more. As discussed in the inline comment, there is some ambiguity when it comes to defining the quotient if the quotient has dimensions of size one. You can basically keep them or remove them, and both choices would give you a LinearLayout such that `A = B * C`. We choose to keep them, as removing them would make the behaviour too unpredictable. The invariant here is that the dimensions of the result of `divideLeft` are the same as the dimensions of the left input. The right input may have fewer dimensions than the left input, though. --- include/triton/Tools/LinearLayout.h | 16 +++++ lib/Tools/LinearLayout.cpp | 74 +++++++++++++++++++ unittest/Tools/LinearLayoutTest.cpp | 107 ++++++++++++++++++++++++++++ 3 files changed, 197 insertions(+) diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 37a3d3de90de..595f0e583c2b 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -610,6 +610,22 @@ class LinearLayout { return *this; } + // Compute a C such that A = B * C if it exists. + // In other words, C = B^{-1} * A. + // Note that such a C exists iff (every pair of input/output dim of) A is + // of the form + // [[B, 0], + // [0, C]] + // as a matrix, whenever those dimensions are present in B. + // + // C will always have the same input/output dimensions as A. + // When there are dimensions of size 1 there is some ambiguity in the + // division, as in `operator*` we treat missing dimensions as dimensions + // of size 1 whenever it makes sense to do so. The rule that C has the + // same dimensions as A ensures that C is well-defined. 
+ friend std::optional divideLeft(const LinearLayout &A, + const LinearLayout &B); + // Returns true if this layout acts trivially (as the identity) on the given // dimensions. This means that it's the identity on those dimensions, and it // does not map other dimensions onto those or these onto other dimensions. diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 47255220c346..bd868c7df091 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -560,6 +560,80 @@ LinearLayout LinearLayout::concatOuts(const LinearLayout &other) const { /*requiresSurjective=*/false); } +std::optional divideLeft(const LinearLayout &A, + const LinearLayout &B) { + // Compute a C such that A = B * C if it exists. + // Note that such a C exists iff (every pair of input/output dim of) A is of + // the form + // [[B, 0], + // [0, C]] + // as a matrix, whenever those dimensions are present in B. + for (StringAttr dim : B.getInDimNames()) { + if (!llvm::is_contained(A.getInDimNames(), dim)) + return std::nullopt; + } + for (StringAttr dim : B.getOutDimNames()) { + if (!llvm::is_contained(A.getOutDimNames(), dim)) + return std::nullopt; + } + // Compute candidate C's sizes for output dimensions (from log2 sizes). + llvm::MapVector cOutDimSizes; + for (StringAttr outDim : A.getOutDimNames()) { + int outA = A.getOutDimSizeLog2(outDim); + int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0; + int outC = outA - outB; + if (outC < 0) + return std::nullopt; + cOutDimSizes[outDim] = 1 << outC; + } + + LinearLayout::BasesT cBases; + for (StringAttr inDim : A.getInDimNames()) { + int inA = A.getInDimSizeLog2(inDim); + int inB = B.hasInDim(inDim) ? B.getInDimSizeLog2(inDim) : 0; + int inC = inA - inB; + if (inC < 0) + return std::nullopt; + + std::vector> basesForDim; + // Check that A's first inB entries agree with B. + for (int i = 0; i < inB; ++i) { + for (StringAttr outDim : A.getOutDimNames()) { + int expected = B.hasOutDim(outDim) ? 
B.getBasis(inDim, i, outDim) : 0; + int actual = A.getBasis(inDim, i, outDim); + if (actual != expected) + return std::nullopt; + } + } + + // Extract the candidate C bases from the remaining (shifted) entries in A. + for (int i = inB; i < inA; ++i) { + std::vector candidateBasis; + for (StringAttr outDim : llvm::make_first_range(cOutDimSizes)) { + int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0; + int v = A.getBasis(inDim, i, outDim); + + // The lower outB bits must be zero. + if ((v & ((1 << outB) - 1)) != 0) + return std::nullopt; + candidateBasis.push_back(v >> outB); + } + basesForDim.push_back(std::move(candidateBasis)); + } + cBases[inDim] = basesForDim; + } + + SmallVector> COutDims; + for (auto [outDim, outC] : cOutDimSizes) { + COutDims.push_back({outDim, outC}); + } + // If the layouts A and B are surjective, then C should also be surjective. + LinearLayout C(std::move(cBases), COutDims, + /*requireSurjective=*/A.isSurjective() && B.isSurjective()); + assert(B * C == A); + return C; +} + LinearLayout operator*(LinearLayout inner, LinearLayout outer) { // Check that dims common to outer and inner have the same relative order. auto inDims = supremum(llvm::to_vector(inner.getInDimNames()), diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index 4f89bc9c0ca0..9756eae79fa2 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -921,6 +921,113 @@ TEST(SupremumTest, ErrorOnInconsistentOrder) { ASSERT_DEATH({ supremum(x, y); }, "Supremum does not exist"); } #endif + +TEST_F(LinearLayoutTest, DivideLeft_Basic) { + // Test division when A = B * C. 
+ auto B = LinearLayout::identity1D(8, S("in"), S("out")); + auto C = LinearLayout::zeros1D(16, S("in"), S("out")); + auto isC = divideLeft(B * C, B); + EXPECT_TRUE(isC.has_value()); + EXPECT_EQ(isC.value(), C); + + auto isB = divideLeft(C * B, C); + EXPECT_TRUE(isB.has_value()); + EXPECT_EQ(isB.value(), B); +} + +TEST_F(LinearLayoutTest, DivideLeft_NonMatchingDims) { + // If B contains an extra input dimension not present in A, division should + // fail. + LinearLayout A = LinearLayout::identity1D(32, S("in"), S("out")); + LinearLayout B({{S("in"), {{1}, {2}, {4}, {8}}}, {S("extra"), {{0}}}}, + {S("out")}); + auto candidateOpt = divideLeft(A, B); + EXPECT_FALSE(candidateOpt.has_value()); +} + +TEST_F(LinearLayoutTest, DivideLeft_Simple) { + EXPECT_EQ(divideLeft(LinearLayout::identity1D(8, S("in"), S("out")), + LinearLayout::identity1D(4, S("in"), S("out"))), + LinearLayout::identity1D(2, S("in"), S("out"))); + + EXPECT_EQ(divideLeft(LinearLayout::identity1D(8, S("in"), S("out")), + LinearLayout::identity1D(8, S("in"), S("out"))), + LinearLayout::identity1D(1, S("in"), S("out"))); +} + +TEST_F(LinearLayoutTest, DivideLeft_2D) { + LinearLayout l1( + { + {S("in1"), {{1, 1}, {2, 2}, {0, 8}, {0, 4}}}, + {S("in2"), {{0, 2}, {0, 1}}}, + }, + {S("out1"), S("out2")}); + LinearLayout l2( + { + {S("in1"), {{1, 1}, {2, 2}}}, + {S("in2"), {{0, 2}, {0, 1}}}, + }, + {S("out1"), S("out2")}); + LinearLayout l3({{S("in1"), {{0, 2}, {0, 1}}}, {S("in2"), {}}}, + {S("out1"), S("out2")}); + ASSERT_EQ(l2 * l3, l1); + ASSERT_EQ(divideLeft(l1, l2).value(), l3); +} + +TEST_F(LinearLayoutTest, DivideLeft_EliminateInDim) { + LinearLayout l1( + { + {S("in2"), {{0, 1}, {1, 0}}}, + {S("in1"), {{2, 0}, {0, 2}}}, + }, + {S("out1"), S("out2")}); + LinearLayout l2({{S("in2"), {{0, 1}, {1, 0}}}}, {S("out1"), S("out2")}); + LinearLayout l3({{S("in2"), {}}, {S("in1"), {{1, 0}, {0, 1}}}}, + {S("out1"), S("out2")}); + ASSERT_EQ(l2 * l3, l1); + EXPECT_EQ(divideLeft(l1, l2).value(), l3); + + LinearLayout 
l4({{S("in1"), {{0, 1}, {0, 2}}}, {S("in2"), {}}}, + {S("out1"), S("out2")}); + LinearLayout l5({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")}); + LinearLayout l6({{S("in1"), {}}, {S("in2"), {}}}, {S("out1"), S("out2")}); + ASSERT_EQ(l5 * l6, l4); + EXPECT_EQ(divideLeft(l4, l5).value(), l6); + + LinearLayout l7({{S("in1"), {}}, {S("in2"), {{0, 1}}}, {S("in3"), {}}}, + {S("out1"), S("out2")}); + LinearLayout l8({{S("in2"), {{0, 1}}}}, {S("out1"), S("out2")}); + LinearLayout l9({{S("in1"), {}}, {S("in2"), {}}, {S("in3"), {}}}, + {S("out1"), S("out2")}); + ASSERT_EQ(l8 * l9, l7); + EXPECT_EQ(divideLeft(l7, l8).value(), l9); +} + +TEST_F(LinearLayoutTest, DivideLeft_EliminateOutDim) { + LinearLayout l1( + { + {S("in2"), {{1, 0}, {1, 0}}}, + {S("in1"), {{2, 0}, {0, 1}}}, + }, + {S("out1"), S("out2")}); + LinearLayout l2({{S("in2"), {{1, 0}, {1, 0}}}}, {S("out1"), S("out2")}); + LinearLayout l3({{S("in2"), {}}, {S("in1"), {{1, 0}, {0, 1}}}}, + {S("out1"), S("out2")}); + ASSERT_EQ(l2 * l3, l1); + EXPECT_EQ(divideLeft(l1, l2).value(), l3); + + LinearLayout l4( + { + {S("in1"), {{0, 1}, {0, 2}}}, + }, + {S("out1"), S("out2")}); + using BasesArray = + ArrayRef>>>; + LinearLayout l5(BasesArray{}, {S("out1")}); + LinearLayout l6({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")}); + ASSERT_EQ(l5 * l6, l4); + EXPECT_EQ(divideLeft(l4, l5).value(), l6); +} } // anonymous namespace } // namespace mlir::triton From 0349b73ae4d436424f517195d2ad27ac25d87e32 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 23 Apr 2025 15:31:14 +0100 Subject: [PATCH 2/2] [LAYOUTS] Use divideLeft on layout inference Now the duality between forward and backward pass should be crystal clear --- include/triton/Tools/LinearLayout.h | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 22 +++------------------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 595f0e583c2b..ee8fdd70df5f 100644 --- 
a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -503,7 +503,7 @@ class LinearLayout { // This only works across the first (i.e. the most-minor) dimension of in/out. // If you want it to work across more dimensions, flatten the layout. // - // TODO(jlebar): Replace with divideLeft. + // TODO: Replace the uses with flattenIns/Outs + divideLeft. int32_t getNumConsecutiveInOut() const; // Reorders the in/out dimensions of the layout. This is mostly cosmetic diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 6be47c9c34ff..5b46a765f516 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -428,28 +428,12 @@ LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl, std::optional loc) { auto kRegister = StringAttr::get(ctx, "register"); auto outDims = llvm::to_vector(inLl.getOutDimNames()); + auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]); if (fwdInference) { - auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]); outLl = split * inLl; } else { - // TODO This requires a division algorithm! - // Implement manually ll.divideLeft(split) - auto contiguousElems = - LinearEncodingAttr::get(ctx, inLl).getContigPerThread(); - if (contiguousElems[axis] > 1) { - LinearLayout::BasesT newBases; - for (const auto &basesDim : inLl.getBases()) { - std::vector> newBasesDim; - for (auto base : basesDim.second) { - if (base[axis] == 1) { - continue; - } - base[axis] /= 2; - newBasesDim.push_back(std::move(base)); - } - newBases.insert({basesDim.first, std::move(newBasesDim)}); - } - outLl = LinearLayout(std::move(newBases), std::move(outDims)); + if (auto div = divideLeft(inLl, split)) { + outLl = *div; } else { return emitOptionalError(loc, "Fp4ToFpOp/SplitOp requires at least 2 elements "