diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index cfc4c0d13bbe..9ddec8881269 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -681,13 +681,6 @@ class LinearLayout { // (i.e. every input bit affects the output). llvm::MapVector getFreeVariableMasks() const; - // Increase an input dimension without affecting the output dimension. The - // added free variables are mapped to 0, ensuring that the new input - // dimensions correspond directly to the existing output space. The function - // errors out if `newInDimSize` is less than the current size or the new size - // is not a power of 2. - LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; - std::string toString() const; friend bool operator==(LinearLayout lhs, LinearLayout rhs); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index e6d921464bc9..3a8be9ee3347 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -663,42 +663,8 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, StringAttr kLane = StringAttr::get(ctx, "lane"); StringAttr kWarp = StringAttr::get(ctx, "warp"); StringAttr kBlock = StringAttr::get(ctx, "block"); - auto numSrcRegs = srcLayout->getInDimSize(kRegister); - auto numDstRegs = dstLayout->getInDimSize(kRegister); - // The `invertAndCompose` function will generate a layout that is injective - // by assigning new output dimensions to free variables. For instance, - // consider a scenario where `srcLayout` has a free variable in the lane - // dimension, while `dstLayout` has two free variables in the lane - // dimension and also a larger number of registers. - // The injective form of `srcLayout` will add only a single additional row - // to the transformation matrix, whereas the injective form of `dstLayout` - // will add two additional rows. This discrepancy causes misleading results - // because the matrices end up with a different number of rows. - // - // Take `dstLayout ⋅ srcLayout^-1` as an example: - // - // - `injective(dstLayout)`: [n, m] → [n + 2, m] - // - `injective(srcLayout)`: [n, m] → [n + 1, m] - // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] - // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + - // 1] → [n + 2, n + 1] - // - // Here, the `(n + 1)`-th row added by `dstLayout` represents the free - // variable in registers, and the `(n + 2)`-th row represents the free - // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` - // represents the free variable in lanes. As a result, the `(n + 1)`-th row - // in two layouts do not correspond to the same free variable. - // - // To address this issue, we pad the free variables in `srcLayout` and - // `dstLayout` to ensure they have the same number of registers. This - // guarantees that the resulting matrices have the same number of rows, - // ensuring consistency in the composition process. - auto numRegs = std::max(numSrcRegs, numDstRegs); - auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); - auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); - // comp describes the layout function to create dst from src. - LinearLayout comp = - dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); + + auto comp = dstLayout->invertAndCompose(*srcLayout); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 84d30097cb70..7e8f6b783609 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -315,14 +315,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // TODO(Keren): implement warp shuffle instead of using the general // approach that uses shared memory return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, kRegister) || - dstLayout.getInDimSize(kRegister) != - srcLayout.getInDimSize(kRegister)) { + } else if (llvm::is_contained(dims, kRegister)) { // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread( - op, dstLayout.getFreeVariableMasks()[kRegister], - dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); + return transferWithinThread(op, *conversion, adaptor, rewriter); } else { // Cast 5. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion. @@ -332,8 +328,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } LogicalResult - transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, - const LinearLayout &conversion, OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); @@ -343,16 +339,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector outVals(numRegs); - for (int i = 0; i < numRegs; i++) { - // Remove free masks from the register index - // For example, if idx = 0b00111, and masks = 0b00100, then we get - // 0b00011. It means that register 7 (0b111) has the same value as - // register 3 (0b011). - auto idx = i & (~regMasks); - auto srcIdx = conversion.hasInDim(kRegister) - ? conversion.apply({{kRegister, idx}}).begin()->second - : idx; + SmallVector outVals(conversion.getInDimSize(kRegister)); + for (int i = 0; i < outVals.size(); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; outVals[i] = inVals[srcIdx]; } Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 3a81231ac863..0ab563908a60 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -112,30 +112,6 @@ std::unique_ptr getMatrix(const LinearLayout &layout) { return m; } -// Get a matrix for `layout` with its codomain expanded so it's injective, i.e. -// each input element maps to a unique output element. We do this by finding -// columns that are equal to 0 and adding a new row with a 1 in that column. -std::tuple, int /*numRows*/, int /*numCols*/> -getInjectiveMat(const LinearLayout &layout) { - int numRows = layout.getTotalOutDimSizeLog2(); - int numCols = layout.getTotalInDimSizeLog2(); - std::unique_ptr mat = getMatrix(layout); - - // Bits of mat or-reduced along the columns (so there's just one row). - uint64_t colBits = 0; - for (int r = 0; r < numRows; r++) { - colBits |= mat[r]; - } - auto expanded = std::unique_ptr(new uint64_t[numRows + numCols]); - std::memcpy(expanded.get(), mat.get(), numRows * sizeof(uint64_t)); - for (int c = 0; c < numCols; c++) { - if ((colBits & (1 << c)) == 0) { - expanded[numRows++] = (1 << c); - } - } - return std::make_tuple(std::move(expanded), numRows, numCols); -} - // Compute the rank of the matrix formed by taking the bases for the given // outDim as columns. In other words, finds the number of linearly-independent // bases for this output dimension. @@ -804,118 +780,179 @@ LinearLayout LinearLayout::compose(const LinearLayout &outer) const { compositionIsSurjective); } -LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { - assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getOutDimNames()); - for (StringAttr outDim : getOutDimNames()) { - assert(getOutDimSize(outDim) <= outer.getOutDimSize(outDim)); +namespace { +std::unique_ptr concatMatrices(const LinearLayout &A, + const LinearLayout &B) { + // In plain words, "convert_layout does not change the shape of a tensor" + assert(A.getTotalOutDimSizeLog2() == B.getTotalOutDimSizeLog2() && + "Matrices must have the same number of output dimensions"); + int numRows = A.getTotalOutDimSizeLog2(); + int numColsA = A.getTotalInDimSizeLog2(); + + // rref expects the lower bits to be the lower indices of the matrix + auto concat = getMatrix(A); + auto BMat = getMatrix(B); + for (int r = 0; r < numRows; r++) { + concat[r] |= BMat[r] << numColsA; } - assert(outer.isSurjective()); + return concat; +} - // Make both `this` and `outer` injective. We need to do this on the - // `outer` layout because we can't invert a non-injective function. We - // choose to do so on the `this` layout as well. The rest of the comment - // explains why we make that choice. - // - // Recall from the header that C = A.invertAndCompose(B) just means that - // A(x) = B(C(x)). - // - // Sometimes we may have a choice of multiple values for a particular - // C(x). For example, if A(1) = B(0) = B(1) = 0, then C(1) can be either 0 - // or 1. - // - // We want to choose C such that C(x) != 0 where possible. For example, - // suppose we are transferring from registers to registers and we have the - // following layouts. - // - // A(thread=1, block=0) = 1 - // A(thread=2, block=0) = 2 - // A(thread=0, block=1) = 0 - // - // B(thread=1, block=0) = 2 - // B(thread=2, block=0) = 1 - // B(thread=0, block=1) = 0 - // - // Notice that A and B both have the same data in each of their two - // blocks. So if we want to transfer from A to B, we don't need to cross - // blocks, which is expensive. We want A.invertAndCompose(B) to reflect - // that choice. - // - // Let A' be A with the last line changed to "=4", and similarly for B'. - // When transferring from A' to B', we can't cross blocks even if we wanted - // to, because the two blocks now have different data. But also, any - // mapping of thread+block from A' to B' is also valid for mapping from A - // to B. - // - // Thus making A and B injective encodes our desire not to cross blocks, - // or more generally our desire that C(x) != 0 where possible. - auto [matThis, numRowsThis, numColsThis] = getInjectiveMat(*this); - auto [matOuter, numRowsOuter, numColsOuter] = getInjectiveMat( - outer.transposeOuts(llvm::to_vector(this->getOutDimNames()))); - - // Concatenate `matOuter` and `matThis` horizontally (i.e. `matThis` - // is to the right of `matOuter`). - int combinedNumRows = std::max(numRowsThis, numRowsOuter); - int combinedNumCols = numColsThis + numColsOuter; - assert(combinedNumCols <= 64 && "Can't handle huge layouts"); - - std::unique_ptr m(new uint64_t[combinedNumRows]()); - for (int r = 0; r < numRowsOuter; r++) { - m[r] = matOuter[r]; - } - for (int r = 0; r < numRowsThis; r++) { - m[r] |= matThis[r] << numColsOuter; - } - - // Perform Gaussian elimination on `m`. Because `outer` was modified to - // be bijective, the first half of the matrix should be the identity - // matrix. The remaining half are the bases for the combined - // transformation. - // - // `stride` is specified in number of 64-bit words per row, and we pack - // our matrix so that there's only one uint64_t per row. - f2reduce::inplace_rref_strided(m.get(), combinedNumRows, combinedNumCols, +LinearLayout lstsq(const LinearLayout &A, const LinearLayout &B) { + // Solve the least square system AX = B for A = outer, B = *this + // and return the least square solution X of minimal norm + // A and B may not be surjective, but we assume that Im(B) \subset Im(A) + // Sketch of the algorithm: + // https://github.com/triton-lang/triton/pull/5309#discussion_r1869084111 + int numRows = A.getTotalOutDimSizeLog2(); + int numColsA = A.getTotalInDimSizeLog2(); + int numColsB = B.getTotalInDimSizeLog2(); + int numCols = numColsA + numColsB; + std::unique_ptr combinedMat = concatMatrices(A, B); + f2reduce::inplace_rref_strided(combinedMat.get(), numRows, numCols, /*stride=*/1); - // Check that the first half of the matrix is indeed the identity. - for (int r = 0; r < std::min(numRowsOuter, numColsOuter); r++) { - for (int c = 0; c < std::min(numColsOuter, numRowsOuter); c++) { - if (((m[r] >> c) & 1) != (r == c ? 1 : 0)) { - llvm::report_fatal_error("First half of the matrix was not the " - "identity, bug in invertAndCompose"); - } + // Compute the pivot columns + // Since A and B have the same image, each row will either have a pivot + // or will be all zeros + SmallVector pivotCols; + for (int r = 0; r < numRows; r++) { + auto row = combinedMat[r]; + if (row == 0) { + continue; } + int c = __builtin_ctzll(row); + assert(c < numColsA && "Precondition broken. Im(B) not contained in Im(A)"); + assert(pivotCols.empty() || + pivotCols.back() < c && "Pivot columns are not in increasing order"); + pivotCols.push_back(c); + } + + // Extract A^{-1}B and complete the matrix using zeros + std::unique_ptr retMat(new uint64_t[numColsA]()); + int j = 0; + for (int r = 0; r < numColsA; r++) { + auto isPivot = j < pivotCols.size() && pivotCols[j] == r; + retMat[r] = isPivot ? combinedMat[j++] >> numColsA : 0; } // We need names for the in/out dim of the flattened layout we're going to // read off from `m`. These could be anything, doesn't matter. - StringAttr inDim1D = *getInDimNames().begin(); - StringAttr outDim1D = *getOutDimNames().begin(); + StringAttr inDim1D = *A.getInDimNames().begin(); + StringAttr outDim1D = *A.getOutDimNames().begin(); // Read off the new bases. These are for a flattened 1D -> 1D - // transformation from `this`'s in-dims to `outer`'s in-dims. - BasesT newBases; - auto &bs = newBases[inDim1D]; - for (int c = 0; c < numColsThis; c++) { + LinearLayout::BasesT retBases; + auto &bs = retBases[inDim1D]; + for (int c = 0; c < numColsB; c++) { int32_t basis = 0; - for (int r = 0; r < numRowsOuter; r++) { - basis |= (m[r] >> (numColsOuter + c) & 1) << r; + for (int r = 0; r < numColsA; r++) { + basis |= (retMat[r] >> c & 1) << r; } bs.push_back({basis}); } - LinearLayout flatComposed(std::move(newBases), - {{outDim1D, outer.getTotalInDimSize()}}, + LinearLayout retFlattened(std::move(retBases), + {{outDim1D, A.getTotalInDimSize()}}, /*requireSurjective=*/false); SmallVector> retInDims; SmallVector> retOutDims; - for (StringAttr dim : getInDimNames()) { - retInDims.push_back({dim, getInDimSize(dim)}); + for (StringAttr dim : B.getInDimNames()) { + retInDims.push_back({dim, B.getInDimSize(dim)}); + } + for (StringAttr dim : A.getInDimNames()) { + retOutDims.push_back({dim, A.getInDimSize(dim)}); + } + return retFlattened.reshapeIns(retInDims).reshapeOuts(retOutDims); +} + +} // namespace + +LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { + // TODO(Lezcano) Make friend and perhaps rename to `convertFrom` or `lstsq` + // For this, we need to implement our LLVM lowerings by inverting the "outer" + // layout, and then iterating over the elements from the "this" layout and + // fetching the corresponding element from the "outer" layout. This exercises + // the broadcasting that we incentivise via choosing the minimum norm solution + // in lstsq. + + // The order of dims does not matter. We choose to transpose outer + auto outDims = llvm::to_vector(getOutDimNames()); + assertDimsEqualIgnoringOrder(outDims, outer.getOutDimNames()); + const auto &B = *this; + const auto A = outer.transposeOuts(outDims); + for (auto dim : outDims) { + assert(A.getOutDimSize(dim) == B.getOutDimSize(dim) && + "Convert layout does not change the shape of a tensor"); + } + + // We'll write A^{-1} to mean the inverse or the pseudo-inverse of A + // We are computing A^{-1}B so A must be surjective so that + // it has a left inverse. + assert(A.isSurjective()); + + // Broadcasting heuristic + // Imagine we have two layouts with `warps = [[0, 0],  [0, 0]]` + // (broadcasting) on both layouts. We could map any warp to any warp in the + // conversion. Now, we want to map them as the identity map, to mark that + // nothing needs to be done there (`lstsq` would map all the warps to the + // zero warp, minimum norm solution). The heuristic here is as follows: + // - If a dimension is the same for both layouts, we want to map it as the + // identity + // Equivalently, we don't add it to the conversion + // - Otherwise, we just call lstsq (i.e. map all the equivalent elements + // to the same input element) to take advantage of broadcasting in shared + // memory and avoid saving repeated elements in shared memory + SmallVector identityDims; + for (auto dim : A.getInDimNames()) { + if (B.hasInDim(dim) && + A.sublayout(dim, outDims) == B.sublayout(dim, outDims)) { + identityDims.push_back(dim); + } + } + SmallVector ANonIdentityInDims; + SmallVector BNonIdentityInDims; + for (auto dim : A.getInDimNames()) { + if (!llvm::is_contained(identityDims, dim)) { + ANonIdentityInDims.push_back(dim); + } } - for (StringAttr dim : outer.getInDimNames()) { - retOutDims.push_back({dim, outer.getInDimSize(dim)}); + for (auto dim : B.getInDimNames()) { + if (!llvm::is_contained(identityDims, dim)) { + BNonIdentityInDims.push_back(dim); + } + } + + auto AReduced = A.sublayout(ANonIdentityInDims, outDims); + auto BReduced = B.sublayout(BNonIdentityInDims, outDims); + + // If one is empty, the other must be empty as well + assert((AReduced == LinearLayout::empty()) == + (BReduced == LinearLayout::empty())); + bool isEmpty = AReduced == LinearLayout::empty(); + + auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced); + + // TODO(Lezcano): We should return the reduced layout instead of re-adding the + // identity maps. With this, we'll be able to kill `minimalCvtLayout` + + // Add the identity maps for the dimensions that are the same for both layouts + for (auto dim : identityDims) { + ret *= LinearLayout::identity1D(A.getInDimSize(dim), dim, dim); } - return flatComposed.reshapeIns(retInDims).reshapeOuts(retOutDims); + + // Reshape the result + SmallVector> inDimsA; + SmallVector> inDimsB; + for (auto dim : A.getInDimNames()) { + inDimsA.push_back({dim, A.getInDimSize(dim)}); + } + for (auto dim : B.getInDimNames()) { + inDimsB.push_back({dim, B.getInDimSize(dim)}); + } + ret = ret.reshapeIns(inDimsB).reshapeOuts(inDimsA); + return ret; } llvm::MapVector @@ -1004,21 +1041,6 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { return true; } -LinearLayout LinearLayout::resize(StringAttr inDim, - int32_t newInDimSize) const { - BasesT bases = getBases(); - assert(bases.contains(inDim) && "inDim not in layout"); - assert(llvm::isPowerOf2_32(newInDimSize) && - "newInDimSize must be a power of 2"); - assert(newInDimSize >= getInDimSize(inDim) && - "newInDimSize must be >= old size"); - auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); - for (int i = 0; i < numFreeVariables; i++) { - bases[inDim].push_back(std::vector(getNumOutDims(), 0)); - } - return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); -} - std::string LinearLayout::toString() const { // Start with a newline because we print out a bulleted list; it doesn't // make sense for the first line of this list to be on the same line as diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 61911356af65..a97ac476cbad 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1698,8 +1698,7 @@ module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: convert_single_element // CHECK-NOT: llvm.store // CHECK-NOT: llvm.load - // CHECK: llvm.insertvalue - // CHECK: llvm.extractvalue + // CHECK: llvm.return tt.func public @convert_single_element() attributes {noinline = false} { %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index 897172fd6d34..d6b94e83f012 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -410,26 +410,6 @@ TEST_F(LinearLayoutTest, InvertAndCompose_NonInjective) { EXPECT_EQ(composition.compose(l2), l1); } -TEST_F(LinearLayoutTest, InvertAndCompose_SmallerResult) { - // The domain of l2 is [0,16), but the codomain of the result is only [0,8), - // because there's no value v in the codomain of l1 such that l2^-1(v) >= 8. - LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}}, {S("out")}); - LinearLayout l2({{S("in2"), {{4}, {1}, {2}, {8}}}}, {S("out")}); - // Pseudo-inverse of l2 is - // - // out(1) = 2 - // out(2) = 4 - // out(4) = 1 - // out(8) = 8 - // - // Composing with l1 gives back l2^-1 without the out(8) entry. - LinearLayout composition = l1.invertAndCompose(l2); - EXPECT_EQ(composition, - LinearLayout({{S("in1"), {{2}, {4}, {1}}}}, {{S("in2"), 16}}, - /*requireSurjective=*/false)); - EXPECT_TRUE(composition.compose(l2).equalIgnoringOutDimSizes(l1)); -} - TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedInDim) { LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); LinearLayout l2({{S("in"), {{4}, {1}, {2}}}}, {S("out")}); @@ -514,8 +494,10 @@ TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims) { LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); LinearLayout l2({{S("in3"), {{1}, {2}, {4}}}, {S("in4"), {{0}}}}, {S("out")}); LinearLayout c = l1.invertAndCompose(l2); - EXPECT_EQ(c, LinearLayout::identity1D(8, S("in1"), S("in3")) * - LinearLayout::identity1D(2, S("in2"), S("in4"))); + EXPECT_EQ(c, LinearLayout( + {{S("in1"), {{1, 0}, {2, 0}, {4, 0}}}, {S("in2"), {{0, 0}}}}, + {{S("in3"), 8}, {S("in4"), 2}}, + /*requireSurjective=*/false)); EXPECT_EQ(c.compose(l2), l1.transposeOuts(llvm::to_vector(l2.getOutDimNames()))); } @@ -525,8 +507,9 @@ TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims2) { LinearLayout b({{S("in3"), {{2}, {1}}}, {S("in4"), {{0}}}}, {S("out")}); LinearLayout c = a.invertAndCompose(b); EXPECT_EQ(c, - LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 1}}}}, - {S("in3"), S("in4")})); + LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 0}}}}, + {{S("in3"), 4}, {S("in4"), 2}}, + /*requireSurjective=*/false)); EXPECT_EQ(c.compose(b), a.transposeOuts(llvm::to_vector(b.getOutDimNames()))); } @@ -746,40 +729,6 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { ASSERT_TRUE(quotientLayout.has_value()); ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); } - -TEST_F(LinearLayoutTest, Resize) { - auto init = LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")}); - EXPECT_EQ(init.resize(S("in0"), 8), - LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}, {0, 0}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); - EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); - EXPECT_EQ(init.resize(S("in1"), 8), - LinearLayout( - { - {S("in0"), {{0, 1}, {0, 2}}}, - {S("in1"), {{1, 0}, {2, 0}, {0, 0}}}, - {S("in2"), {}}, - }, - {S("dim0"), S("dim1")})); -} - } // anonymous namespace } // namespace mlir::triton