Skip to content

Commit 1a2d481

Browse files
committed
[mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect
This patch deletes `vector.matrix_multiply` and `vector.flat_transpose`, which are thin wrappers around the corresponding LLVM intrinsics: - `llvm.intr.matrix.multiply` - `llvm.intr.matrix.transpose` These Vector dialect ops did not provide additional semantics or abstraction beyond the LLVM intrinsics. Their removal simplifies the lowering pipeline without losing any functionality. The lowering chains: - `vector.contract` → `vector.matrix_multiply` → `llvm.intr.matrix.multiply` - `vector.transpose` → `vector.flat_transpose` → `llvm.intr.matrix.transpose` are now replaced with: - `vector.contract` → `llvm.intr.matrix.multiply` - `vector.transpose` → `llvm.intr.matrix.transpose` This was accomplished by directly replacing: - `vector::MatrixMultiplyOp` with `LLVM::MatrixMultiplyOp` - `vector::FlatTransposeOp` with `LLVM::MatrixTransposeOp` Note: This change introduces a build-time dependency from `Vector` to `LLVM`. Ideally, such dependencies should be confined to dialect conversion (`ConvertVectorToLLVM`). However, moving the lowering code there would introduce notable churn, so this patch leaves the new dependency in place for now.
1 parent 85349b4 commit 1a2d481

File tree

17 files changed

+26
-312
lines changed

17 files changed

+26
-312
lines changed

mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@
1313
namespace mlir {
1414
class LLVMTypeConverter;
1515

16-
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
17-
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
18-
/// will be needed when invoking LLVM.
19-
void populateVectorToLLVMMatrixConversionPatterns(
20-
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
21-
2216
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
2317
void populateVectorToLLVMConversionPatterns(
2418
const LLVMTypeConverter &converter, RewritePatternSet &patterns,

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -2788,124 +2788,6 @@ def Vector_PrintOp :
27882788
}];
27892789
}
27902790

2791-
//===----------------------------------------------------------------------===//
2792-
// Ops used for supporting progressive lowering and conversion type changes.
2793-
// The Ops are typically not used directly by higher level dialects, but are
2794-
// used by intra-dialect rewriting rules to bring vector operations closer
2795-
// to the hardware ISA.
2796-
//===----------------------------------------------------------------------===//
2797-
2798-
/// Vector dialect matrix multiplication op that operates on flattened 1-D
2799-
/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
2800-
/// This may seem redundant with vector.contract but it serves the purposes of
2801-
/// more progressive lowering and localized type conversion on the path:
2802-
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
2803-
def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
2804-
PredOpTrait<"lhs operand and result have same element type",
2805-
TCresVTEtIsSameAsOpBase<0, 0>>,
2806-
PredOpTrait<"rhs operand and result have same element type",
2807-
TCresVTEtIsSameAsOpBase<0, 1>>]>,
2808-
Arguments<(
2809-
// TODO: tighten vector element types that make sense.
2810-
ins FixedVectorOfRankAndType<[1],
2811-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
2812-
FixedVectorOfRankAndType<[1],
2813-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
2814-
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
2815-
Results<(
2816-
outs FixedVectorOfRankAndType<[1],
2817-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
2818-
{
2819-
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
2820-
" MLIR vectors";
2821-
let description = [{
2822-
This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
2823-
purposes of more progressive lowering and localized type conversion.
2824-
Higher levels typically lower matrix multiplications into 'vector.contract'
2825-
operations. Subsequent rewriting rule progressively lower these operations
2826-
into 'vector.matrix_multiply' operations to bring the operations closer
2827-
to the hardware ISA.
2828-
2829-
The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
2830-
and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
2831-
<rhs_columns> and multiplies them. The result matrix is returned embedded in
2832-
the result vector.
2833-
2834-
Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
2835-
support scalable vectors. Hence, this Op is only available for fixed-width
2836-
vectors. Also see:
2837-
2838-
http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
2839-
2840-
Example:
2841-
2842-
```mlir
2843-
%C = vector.matrix_multiply %A, %B
2844-
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
2845-
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
2846-
```
2847-
}];
2848-
let builders = [
2849-
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows,
2850-
"unsigned":$lhsColumns, "unsigned":$rhsColumns),
2851-
[{
2852-
$_state.addOperands({lhs, rhs});
2853-
$_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows));
2854-
$_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns));
2855-
$_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns));
2856-
$_state.addTypes(VectorType::get(lhsRows * rhsColumns,
2857-
::llvm::cast<VectorType>(lhs.getType()).getElementType()));
2858-
}]>,
2859-
];
2860-
let assemblyFormat = "$lhs `,` $rhs attr-dict "
2861-
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
2862-
}
2863-
2864-
/// Vector dialect matrix transposition op that operates on flattened 1-D
2865-
/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
2866-
/// This may seem redundant with vector.transpose but it serves the purposes of
2867-
/// more progressive lowering and localized type conversion on the path:
2868-
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
2869-
def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
2870-
PredOpTrait<"source operand and result have same element type",
2871-
TCresVTEtIsSameAsOpBase<0, 0>>]>,
2872-
Arguments<(
2873-
// TODO: tighten vector element types that make sense.
2874-
ins FixedVectorOfRankAndType<[1],
2875-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
2876-
I32Attr:$rows, I32Attr:$columns)>,
2877-
Results<(
2878-
outs FixedVectorOfRankAndType<[1],
2879-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
2880-
let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
2881-
let description = [{
2882-
This is the counterpart of llvm.matrix.transpose in MLIR. It serves
2883-
the purposes of more progressive lowering and localized type conversion.
2884-
Higher levels typically lower matrix transpositions into 'vector.transpose'
2885-
operations. Subsequent rewriting rule progressively lower these operations
2886-
into 'vector.flat_transpose' operations to bring the operations closer
2887-
to the hardware ISA.
2888-
2889-
The `vector.flat_transpose` op treats the 1-D input `matrix` as
2890-
a 2-D matrix with <rows> rows and <columns> columns, and returns the
2891-
transposed matrix in flattened form in 'res'.
2892-
2893-
Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not
2894-
support scalable vectors. Hence, this Op is only available for fixed-width
2895-
vectors. Also see:
2896-
2897-
http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic
2898-
2899-
Example:
2900-
2901-
```mlir
2902-
%1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32}
2903-
: vector<16xf32> -> vector<16xf32>
2904-
```
2905-
}];
2906-
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
2907-
}
2908-
29092791
//===----------------------------------------------------------------------===//
29102792
// SplatOp
29112793
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -184,41 +184,6 @@ class VectorBitCastOpConversion
184184
}
185185
};
186186

187-
/// Conversion pattern for a vector.matrix_multiply.
188-
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
189-
class VectorMatmulOpConversion
190-
: public ConvertOpToLLVMPattern<vector::MatmulOp> {
191-
public:
192-
using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
193-
194-
LogicalResult
195-
matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
196-
ConversionPatternRewriter &rewriter) const override {
197-
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
198-
matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
199-
adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
200-
matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
201-
return success();
202-
}
203-
};
204-
205-
/// Conversion pattern for a vector.flat_transpose.
206-
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
207-
class VectorFlatTransposeOpConversion
208-
: public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
209-
public:
210-
using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
211-
212-
LogicalResult
213-
matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
214-
ConversionPatternRewriter &rewriter) const override {
215-
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
216-
transOp, typeConverter->convertType(transOp.getRes().getType()),
217-
adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
218-
return success();
219-
}
220-
};
221-
222187
/// Overloaded utility that replaces a vector.load, vector.store,
223188
/// vector.maskedload and vector.maskedstore with their respective LLVM
224189
/// couterparts.
@@ -2071,12 +2036,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
20712036
converter);
20722037
}
20732038

2074-
void mlir::populateVectorToLLVMMatrixConversionPatterns(
2075-
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
2076-
patterns.add<VectorMatmulOpConversion>(converter);
2077-
patterns.add<VectorFlatTransposeOpConversion>(converter);
2078-
}
2079-
20802039
namespace {
20812040
struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
20822041
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9696
LLVMTypeConverter converter(&getContext(), options);
9797
RewritePatternSet patterns(&getContext());
9898
populateVectorTransferLoweringPatterns(patterns);
99-
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
10099
populateVectorToLLVMConversionPatterns(
101100
converter, patterns, reassociateFPReductions, force32BitVectorIndices,
102101
useVectorAlignment);
103-
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
104102

105103
// Architecture specific augmentations.
106104
LLVMConversionTarget target(getContext());

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
118118
return converter.isLegal(op);
119119
});
120120
// Manually mark arithmetic-performing vector instructions.
121-
target.addDynamicallyLegalOp<
122-
vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
123-
vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
121+
target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
122+
vector::MultiDimReductionOp, vector::FMAOp,
123+
vector::OuterProductOp, vector::ScanOp>(
124124
[&](Operation *op) { return converter.isLegal(op); });
125125
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126126
arith::ConstantOp, vector::SplatOp>();

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
5050
MLIRTensorDialect
5151
MLIRTransforms
5252
MLIRVectorDialect
53+
MLIRLLVMDialect
5354
MLIRVectorInterfaces
5455
MLIRVectorUtils
5556
)

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,12 +1271,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
12711271
/// %mtb = maybe_transpose
12721272
/// %flattened_a = vector.shape_cast %mta
12731273
/// %flattened_b = vector.shape_cast %mtb
1274-
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
1274+
/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
12751275
/// %mtd = vector.shape_cast %flattened_d
12761276
/// %d = maybe_untranspose %mtd
12771277
/// %e = add %c, %d
12781278
/// ```
1279-
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
12801279
//
12811280
/// This only kicks in when vectorContractLowering is set to `Matmul`.
12821281
/// vector.transpose operations are inserted if the vector.contract op is not a
@@ -1353,8 +1352,12 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
13531352
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
13541353
rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
13551354

1356-
Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1357-
rhsColumns);
1355+
Value mul = rew.create<LLVM::MatrixMultiplyOp>(
1356+
loc,
1357+
VectorType::get(lhsRows * rhsColumns,
1358+
cast<VectorType>(lhs.getType()).getElementType()),
1359+
lhs, rhs, lhsRows, lhsColumns, rhsColumns);
1360+
13581361
mul = rew.create<vector::ShapeCastOp>(
13591362
loc,
13601363
VectorType::get({lhsRows, rhsColumns},

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
337337
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
338338
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
339339
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
340-
Value trans = rewriter.create<vector::FlatTransposeOp>(
340+
Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
341341
loc, flattenedType, matrix, rows, columns);
342342
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
343343
return success();

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,36 +1424,6 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v
14241424

14251425
return %0, %1, %2: vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>
14261426
}
1427-
// -----
1428-
1429-
//===----------------------------------------------------------------------===//
1430-
// vector.matrix_multiply
1431-
//===----------------------------------------------------------------------===//
1432-
1433-
// 4x16 16x3 4x3
1434-
func.func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
1435-
%C = vector.matrix_multiply %A, %B
1436-
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
1437-
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
1438-
return %C: vector<12xf64>
1439-
}
1440-
// CHECK-LABEL: @matrix_ops
1441-
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
1442-
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
1443-
// CHECK-SAME: } : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
1444-
1445-
// -----
1446-
1447-
func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> {
1448-
%C = vector.matrix_multiply %A, %B
1449-
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
1450-
(vector<64xindex>, vector<48xindex>) -> vector<12xindex>
1451-
return %C: vector<12xindex>
1452-
}
1453-
// CHECK-LABEL: @matrix_ops_index
1454-
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
1455-
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
1456-
// CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64>
14571427

14581428
// -----
14591429

@@ -1602,56 +1572,6 @@ func.func @create_mask_1d_scalable(%num_elems : index) -> vector<[4]xi1> {
16021572

16031573
// -----
16041574

1605-
//===----------------------------------------------------------------------===//
1606-
// vector.flat_transpose
1607-
//===----------------------------------------------------------------------===//
1608-
1609-
func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
1610-
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
1611-
: vector<16xf32> -> vector<16xf32>
1612-
return %0 : vector<16xf32>
1613-
}
1614-
1615-
// CHECK-LABEL: func @flat_transpose
1616-
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
1617-
// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
1618-
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
1619-
// CHECK-SAME: vector<16xf32> into vector<16xf32>
1620-
// CHECK: return %[[T]] : vector<16xf32>
1621-
1622-
// -----
1623-
1624-
func.func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> {
1625-
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
1626-
: vector<16xindex> -> vector<16xindex>
1627-
return %0 : vector<16xindex>
1628-
}
1629-
// CHECK-LABEL: func @flat_transpose_index
1630-
// CHECK-SAME: %[[A:.*]]: vector<16xindex>
1631-
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<16xindex> to vector<16xi64>
1632-
// CHECK: %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]]
1633-
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
1634-
// CHECK-SAME: vector<16xi64> into vector<16xi64>
1635-
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<16xi64> to vector<16xindex>
1636-
// CHECK: return %[[T2]] : vector<16xindex>
1637-
1638-
// -----
1639-
1640-
func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
1641-
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
1642-
: vector<16xf32> -> vector<16xf32>
1643-
return %0 : vector<16xf32>
1644-
}
1645-
1646-
// CHECK-LABEL: func @flat_transpose
1647-
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
1648-
// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
1649-
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
1650-
// CHECK-SAME: vector<16xf32> into vector<16xf32>
1651-
// CHECK: return %[[T]] : vector<16xf32>
1652-
1653-
// -----
1654-
16551575
//===----------------------------------------------------------------------===//
16561576
// vector.gather
16571577
//

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,13 +1321,6 @@ func.func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) {
13211321

13221322
// -----
13231323

1324-
func.func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) {
1325-
// expected-error@+1 {{'vector.flat_transpose' op failed to verify that source operand and result have same element type}}
1326-
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf64>
1327-
}
1328-
1329-
// -----
1330-
13311324
func.func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
13321325
// expected-error@+1 {{expects operand to be a memref with identity layout}}
13331326
%0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
@@ -1939,26 +1932,6 @@ func.func @invalid_step_2d() {
19391932

19401933
// -----
19411934

1942-
func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) {
1943-
// expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}}
1944-
%c = vector.matrix_multiply %a, %b {
1945-
lhs_rows = 2: i32,
1946-
lhs_columns = 2: i32 ,
1947-
rhs_columns = 2: i32 }
1948-
: (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64>
1949-
1950-
return
1951-
}
1952-
1953-
// -----
1954-
1955-
func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32> {
1956-
// expected-error @+1 {{'vector.flat_transpose' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[16]xf32>'}}
1957-
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
1958-
: vector<[16]xf32> -> vector<[16]xf32>
1959-
return %0 : vector<[16]xf32>
1960-
}
1961-
19621935
//===----------------------------------------------------------------------===//
19631936
// vector.splat
19641937
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)