@@ -2000,13 +2000,217 @@ struct VectorScalableStepOpLowering
20002000 }
20012001};
20022002
2003+ // / Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
2004+ // / semantics to:
2005+ // / ```
2006+ // / %flattened_a = vector.shape_cast %a
2007+ // / %flattened_b = vector.shape_cast %b
2008+ // / %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
2009+ // / %d = vector.shape_cast %%flattened_d
2010+ // / %e = add %c, %d
2011+ // / ```
2012+ // / `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
2013+ //
2014+ // / This only kicks in when vectorContractLowering is set to Matmul and
2015+ // / the vector.contract op is a row-major matrix multiply.
2016+ class ContractionOpToMatmulOpLowering
2017+ : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
2018+ public:
2019+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
2020+
2021+ ContractionOpToMatmulOpLowering (
2022+ vector::VectorContractLowering vectorContractLowering,
2023+ MLIRContext *context, PatternBenefit benefit = 100 )
2024+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
2025+
2026+ FailureOr<Value>
2027+ matchAndRewriteMaskableOp (vector::ContractionOp op, MaskingOpInterface maskOp,
2028+ PatternRewriter &rewriter) const override ;
2029+ };
2030+
2031+ // / Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
2032+ // / semantics to:
2033+ // / ```
2034+ // / %mta = maybe_transpose
2035+ // / %mtb = maybe_transpose
2036+ // / %flattened_a = vector.shape_cast %mta
2037+ // / %flattened_b = vector.shape_cast %mtb
2038+ // / %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
2039+ // / %mtd = vector.shape_cast %flattened_d
2040+ // / %d = maybe_untranspose %mtd
2041+ // / %e = add %c, %d
2042+ // / ```
2043+ //
2044+ // / This only kicks in when vectorContractLowering is set to `Matmul`.
2045+ // / vector.transpose operations are inserted if the vector.contract op is not a
2046+ // / row-major matrix multiply.
2047+ // /
2048+ // / Scalable vectors are not supported.
2049+ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp (
2050+ vector::ContractionOp op, MaskingOpInterface maskOp,
2051+ PatternRewriter &rew) const {
2052+ // TODO: Support vector.mask.
2053+ if (maskOp)
2054+ return failure ();
2055+
2056+ auto iteratorTypes = op.getIteratorTypes ().getValue ();
2057+ if (!isParallelIterator (iteratorTypes[0 ]) ||
2058+ !isParallelIterator (iteratorTypes[1 ]) ||
2059+ !isReductionIterator (iteratorTypes[2 ]))
2060+ return failure ();
2061+
2062+ Type opResType = op.getType ();
2063+ VectorType vecType = dyn_cast<VectorType>(opResType);
2064+ if (vecType && vecType.isScalable ()) {
2065+ // Note - this is sufficient to reject all cases with scalable vectors.
2066+ return failure ();
2067+ }
2068+
2069+ Type elementType = op.getLhsType ().getElementType ();
2070+ if (!elementType.isIntOrFloat ())
2071+ return failure ();
2072+
2073+ Type dstElementType = vecType ? vecType.getElementType () : opResType;
2074+ if (elementType != dstElementType)
2075+ return failure ();
2076+
2077+ // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
2078+ // Bail out if the contraction cannot be put in this form.
2079+ MLIRContext *ctx = op.getContext ();
2080+ Location loc = op.getLoc ();
2081+ AffineExpr m, n, k;
2082+ bindDims (rew.getContext (), m, n, k);
2083+ // LHS must be A(m, k) or A(k, m).
2084+ Value lhs = op.getLhs ();
2085+ auto lhsMap = op.getIndexingMapsArray ()[0 ];
2086+ if (lhsMap == AffineMap::get (3 , 0 , {k, m}, ctx))
2087+ lhs = rew.create <vector::TransposeOp>(loc, lhs, ArrayRef<int64_t >{1 , 0 });
2088+ else if (lhsMap != AffineMap::get (3 , 0 , {m, k}, ctx))
2089+ return failure ();
2090+
2091+ // RHS must be B(k, n) or B(n, k).
2092+ Value rhs = op.getRhs ();
2093+ auto rhsMap = op.getIndexingMapsArray ()[1 ];
2094+ if (rhsMap == AffineMap::get (3 , 0 , {n, k}, ctx))
2095+ rhs = rew.create <vector::TransposeOp>(loc, rhs, ArrayRef<int64_t >{1 , 0 });
2096+ else if (rhsMap != AffineMap::get (3 , 0 , {k, n}, ctx))
2097+ return failure ();
2098+
2099+ // At this point lhs and rhs are in row-major.
2100+ VectorType lhsType = cast<VectorType>(lhs.getType ());
2101+ VectorType rhsType = cast<VectorType>(rhs.getType ());
2102+ int64_t lhsRows = lhsType.getDimSize (0 );
2103+ int64_t lhsColumns = lhsType.getDimSize (1 );
2104+ int64_t rhsColumns = rhsType.getDimSize (1 );
2105+
2106+ Type flattenedLHSType =
2107+ VectorType::get (lhsType.getNumElements (), lhsType.getElementType ());
2108+ lhs = rew.create <vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
2109+
2110+ Type flattenedRHSType =
2111+ VectorType::get (rhsType.getNumElements (), rhsType.getElementType ());
2112+ rhs = rew.create <vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
2113+
2114+ Value mul = rew.create <LLVM::MatrixMultiplyOp>(
2115+ loc,
2116+ VectorType::get (lhsRows * rhsColumns,
2117+ cast<VectorType>(lhs.getType ()).getElementType ()),
2118+ lhs, rhs, lhsRows, lhsColumns, rhsColumns);
2119+
2120+ mul = rew.create <vector::ShapeCastOp>(
2121+ loc,
2122+ VectorType::get ({lhsRows, rhsColumns},
2123+ getElementTypeOrSelf (op.getAcc ().getType ())),
2124+ mul);
2125+
2126+ // ACC must be C(m, n) or C(n, m).
2127+ auto accMap = op.getIndexingMapsArray ()[2 ];
2128+ if (accMap == AffineMap::get (3 , 0 , {n, m}, ctx))
2129+ mul = rew.create <vector::TransposeOp>(loc, mul, ArrayRef<int64_t >{1 , 0 });
2130+ else if (accMap != AffineMap::get (3 , 0 , {m, n}, ctx))
2131+ llvm_unreachable (" invalid contraction semantics" );
2132+
2133+ Value res =
2134+ isa<IntegerType>(elementType)
2135+ ? static_cast <Value>(rew.create <arith::AddIOp>(loc, op.getAcc (), mul))
2136+ : static_cast <Value>(
2137+ rew.create <arith::AddFOp>(loc, op.getAcc (), mul));
2138+
2139+ return res;
2140+ }
2141+
/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %flat = vector.shape_cast %y            // flatten to 1-D
///   %t = llvm.intr.matrix.transpose %flat   // rows x columns
///   %x = vector.shape_cast %t               // restore 2-D shape
/// Only rank-2 transposes with permutation [1, 0] on fixed-size vectors are
/// handled; everything else is left for other transpose-lowering patterns.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getSourceVectorType();
    VectorType resType = op.getResultVectorType();

    // The llvm.matrix intrinsics require fixed-size vectors.
    if (inputType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "This lowering does not support scalable vectors");

    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();

    // Bail on anything other than a plain 2-D [1, 0] transpose.
    if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
      return failure();
    }

    // Flatten the input, transpose via the LLVM matrix intrinsic, and shape
    // the 1-D result back to the expected 2-D result type.
    Type flattenedType =
        VectorType::get(resType.getNumElements(), resType.getElementType());
    auto matrix =
        rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
    auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
    auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
    Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
        loc, flattenedType, matrix, rows, columns);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
    return success();
  }
};
2186+
20032187} // namespace
20042188
/// Registers the rank-reducing FMA rewrite (VectorFMAOpNDRewritePattern) into
/// `patterns`, using the pattern set's own context and the default benefit.
void mlir::vector::populateVectorRankReducingFMAPattern(
    RewritePatternSet &patterns) {
  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
}
20092193
/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`.
///
/// Given the high benefit (100), this will be prioritised over other
/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
/// only run this registration conditionally.
void mlir::vector::populateVectorContractToMatrixMultiply(
    RewritePatternSet &patterns) {
  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), 100);
}
2203+
/// Pattern to lower `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
///
/// Given the high benefit (100), this will be prioritised over other
/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
/// only run this registration conditionally.
void mlir::vector::populateVectorTransposeToFlatTranspose(
    RewritePatternSet &patterns) {
  patterns.add<TransposeOpLowering>(patterns.getContext(), 100);
}
2213+
20102214// / Populate the given list with patterns that convert from Vector to LLVM.
20112215void mlir::populateVectorToLLVMConversionPatterns (
20122216 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
0 commit comments