@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
}];

let arguments = (ins TransformHandleTypeInterface:$target,
Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
$scalable_sizes);
Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);

let results = (outs);

3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
bool assumeDynamicDimsMatchVecSizes = false);

/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
getVectorizeNdExtract().value_or(false));
getVectorizeNdExtract().value_or(false), false,
getAssumeDynamicDimsMatchVecSizes().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
53 changes: 41 additions & 12 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Contributor:
Some changes on comments are not relevant, can you revert them?

Contributor Author:
Not sure how that crept in. Fixed in this commit.

@@ -222,7 +222,8 @@ struct VectorizationState {
/// canonical vector shape for vectorization.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims);
ArrayRef<bool> inputScalableVecDims,
bool assumeDynamicDimsMatchVecSizes = false);

/// Returns the canonical vector shape used to vectorize the iteration space.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -331,6 +332,14 @@ struct VectorizationState {
/// Global vectorization guard for the incoming rewriter. It's initialized
/// when the vectorization state is initialized.
OpBuilder::InsertionGuard rewriterGuard;

/// Do all dynamic dims match the corresponding vector sizes?
///
/// When a dynamic tensor/memref dimension matches the corresponding vector
/// dimension, masking can be safely skipped, despite the presence of dynamic
/// shapes. Use this flag with care and only for cases where you are
/// confident the assumption holds.
bool assumeDynamicDimsMatchVecSizes = false;
};

LogicalResult
@@ -367,10 +376,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims) {
LogicalResult VectorizationState::initState(RewriterBase &rewriter,
LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
bool assumeDimsMatchVec) {
assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
// Initialize the insertion point.
rewriter.setInsertionPoint(linalgOp);

@@ -470,6 +481,23 @@ Value VectorizationState::getOrCreateMaskFor(
return Value();
}

if (assumeDynamicDimsMatchVecSizes) {
// While for _dynamic_ dim sizes we can _assume_ that the corresponding
// vector sizes match, we still need to check the _static_ dim sizes. Only
// then can we be 100% sure that masking is not required.
if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
[](auto it) {
return std::get<0>(it) == ShapedType::kDynamic
? true
: std::get<0>(it) == std::get<1>(it);
})) {
LDBG("Dynamic + static dimensions match vector sizes, masking is not "
"required.\n");
activeMaskCache[maskingMap] = Value();
return Value();
}
}

// Permute the iteration space value sizes to compute the mask upper bounds.
SmallVector<Value> upperBounds =
applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2507,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
Contributor:
I was wondering if there is a particular reason why this wouldn't work for batch_mmt4d as well?

Contributor Author:
It should. I am not enabling it just yet to keep this PR relatively small. Adding batch_mmt4d would mean more tests and I would rather do it separately.

Contributor:
Alright, makes sense.

hasReductionIterator(linalgOp));
}

LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2564,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}

FailureOr<VectorizationResult>
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
FailureOr<VectorizationResult> mlir::linalg::vectorize(
RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2587,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
VectorizationState state(rewriter);
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
inputScalableVecDims))) {
inputScalableVecDims,
assumeDynamicDimsMatchVecSizes))) {
LDBG("Vectorization state couldn't be initialized\n");
return failure();
}
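For reference, here is a minimal C++ sketch (not part of this diff) of how a caller might exercise the new flag through the public entry point declared in Transforms.h above. The wrapper function name and the surrounding pass context are hypothetical; the vector sizes mirror the transform-dialect tests below.

// Sketch only: assumes the usual MLIR/Linalg includes and that this runs
// inside a pass or pattern that already has a rewriter and a matched op
// (here assumed to be a linalg.mmt4d with one dynamic dimension).
static LogicalResult vectorizeMmt4dAssumingDynDimsMatch(RewriterBase &rewriter,
                                                        Operation *mmt4dOp) {
  // Vector sizes mirror the test case below: [16, 16, 16, 8, [4], 1], where
  // the fifth entry is scalable.
  FailureOr<linalg::VectorizationResult> result = linalg::vectorize(
      rewriter, mmt4dOp,
      /*inputVectorSizes=*/{16, 16, 16, 8, 4, 1},
      /*inputScalableVecDims=*/{false, false, false, false, true, false},
      /*vectorizeNDExtract=*/false,
      /*flatten1DDepthwiseConv=*/false,
      /*assumeDynamicDimsMatchVecSizes=*/true);
  // With the assumption enabled, the generated reads/writes are unmasked even
  // though the shapes are dynamic (see @mmt4d_scalable_with_assume below).
  return success(succeeded(result));
}
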
117 changes: 93 additions & 24 deletions mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
}
}

// -----

///----------------------------------------------------------------------------------------
/// Tests for linalg.mmt4d
///----------------------------------------------------------------------------------------

func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
outs(%C_in: memref<16x16x8x8xf32>)
return
}

// CHECK-LABEL: func.func @mmt4d(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d : !transform.any_op
transform.yield
}
}

// -----

func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
outs(%C_in: memref<16x16x8x?xf32>)
return
}
// CHECK-LABEL: func.func @mmt4d_scalable(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>


module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
transform.yield
}
}

// -----

func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
outs(%C_in: memref<16x16x8x?xf32>)
return
}
// CHECK-LABEL: func.func @mmt4d_scalable_with_assume(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
// CHECK-NOT: mask
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
transform.yield
}
}

///----------------------------------------------------------------------------------------
/// Tests for other Ops
///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
}
}

// -----

func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
outs(%C_in: memref<16x16x8x8xf32>)
return
}

// CHECK-LABEL: func.func @mmt4d(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d : !transform.any_op
transform.yield
}
}

// -----
