@@ -1612,7 +1612,27 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16121612 }
16131613};
16141614
1615- // / For vectors with either leading or trailing unit dim, replaces:
1615+ // Scalable unit dimensions are not supported. Folding such dimensions would
1616+ // require "shifting" the scalable flag onto some other fixed-width dim (e.g.
1617+ // vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
1618+ // future.
1619+ static VectorType dropNonScalableUnitDimFromType (VectorType inVecTy) {
1620+ auto inVecShape = inVecTy.getShape ();
1621+ SmallVector<int64_t > newShape;
1622+ SmallVector<bool > newScalableDims;
1623+ for (auto [dim, isScalable] :
1624+ llvm::zip_equal (inVecShape, inVecTy.getScalableDims ())) {
1625+ if (dim == 1 && !isScalable)
1626+ continue ;
1627+
1628+ newShape.push_back (dim);
1629+ newScalableDims.push_back (isScalable);
1630+ }
1631+
1632+ return VectorType::get (newShape, inVecTy.getElementType (), newScalableDims);
1633+ }
1634+
1635+ // / For vectors with at least an unit dim, replaces:
16161636// / elementwise(a, b)
16171637// / with:
16181638// / sc_a = shape_cast(a)
@@ -1624,20 +1644,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16241644// / required to be rank > 1.
16251645// /
16261646// / Ex:
1627- // / ```
16281647// / %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
16291648// / %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1630- // / ```
16311649// /
16321650// / gets converted to:
16331651// /
1634- // / ```
16351652// / %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
16361653// / %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
16371654// / %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
16381655// / %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
16391656// / %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1640- // / ```
16411657// /
16421658// / Patterns for folding shape_casts should instantly eliminate `%cast_new` and
16431659// / `%cast`.
@@ -1657,42 +1673,29 @@ struct DropUnitDimFromElementwiseOps final
16571673 // guaranteed to have identical shapes (with some exceptions such as
16581674 // `arith.select`) and it suffices to only check one of them.
16591675 auto sourceVectorType = dyn_cast<VectorType>(op->getOperand (0 ).getType ());
1660- if (!sourceVectorType)
1661- return failure ();
1662- if (sourceVectorType.getRank () < 2 )
1663- return failure ();
1664-
1665- bool hasTrailingDimUnitFixed =
1666- ((sourceVectorType.getShape ().back () == 1 ) &&
1667- (!sourceVectorType.getScalableDims ().back ()));
1668- bool hasLeadingDimUnitFixed =
1669- ((sourceVectorType.getShape ().front () == 1 ) &&
1670- (!sourceVectorType.getScalableDims ().front ()));
1671- if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
1676+ if (!sourceVectorType || sourceVectorType.getRank () < 2 )
16721677 return failure ();
16731678
1674- // Drop leading/trailing unit dim by applying vector.shape_cast to all
1675- // operands
1676- int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank () - 1 ;
1677-
16781679 SmallVector<Value> newOperands;
16791680 auto loc = op->getLoc ();
16801681 for (auto operand : op->getOperands ()) {
16811682 auto opVectorType = cast<VectorType>(operand.getType ());
1682- VectorType newVType = VectorType::Builder (opVectorType).dropDim (dim);
1683+ auto newVType = dropNonScalableUnitDimFromType (opVectorType);
1684+ if (newVType == opVectorType)
1685+ return rewriter.notifyMatchFailure (op, " No unit dimension to remove." );
1686+
16831687 auto opSC = rewriter.create <vector::ShapeCastOp>(loc, newVType, operand);
16841688 newOperands.push_back (opSC);
16851689 }
16861690
16871691 VectorType newResultVectorType =
1688- VectorType::Builder (resultVectorType). dropDim (dim );
1689- // Create an updated elementwise Op without leading/trailing unit dim
1692+ dropNonScalableUnitDimFromType (resultVectorType);
1693+ // Create an updated elementwise Op without unit dim.
16901694 Operation *elementwiseOp =
16911695 rewriter.create (loc, op->getName ().getIdentifier (), newOperands,
16921696 newResultVectorType, op->getAttrs ());
16931697
1694- // Restore the leading/trailing unit dim by applying vector.shape_cast
1695- // to the result
1698+ // Restore the unit dim by applying vector.shape_cast to the result.
16961699 rewriter.replaceOpWithNewOp <ShapeCastOp>(op, resultVectorType,
16971700 elementwiseOp->getResult (0 ));
16981701
0 commit comments