-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][tosa] Allow shift operand of tosa::MulOp as non-constant #155197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-tosa Author: None (ShivaChen) ChangesThe shift operand of tosa::MulOp could be non-constant when the dynamic extension enabled. Given that checkConstantOperandMul could check the shift operand according to the extension, we might able to relax the checking in TosaToLinalg. Commutative of MulOp might need to be removed to avoid shift operand been reordered with other operands when the shift operand is non-constant. Full diff: https://github.com/llvm/llvm-project/pull/155197.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..eed428da99192 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -983,7 +983,6 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
def Tosa_MulOp : Tosa_Op<"mul", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
- Commutative,
Pure]> {
let summary = "Multiplication operator.";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..a02d6c97aa5d8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::MulOp>(op)) {
auto shiftVal = cast<tosa::MulOp>(op).getShift();
DenseElementsAttr shiftElem;
- if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
- (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
- return nullptr;
- }
-
- int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ bool shiftIsConstant = true;
+ int32_t shift = 0;
+ if (matchPattern(shiftVal, m_Constant(&shiftElem)))
+ shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ else
+ shiftIsConstant = false;
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
@@ -147,23 +147,24 @@ static Value createLinalgBodyCalculationForElementwiseOp(
Value a = args[0];
Value b = args[1];
- if (shift > 0) {
- auto shiftConst =
- arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
+ if (shift > 0 || !shiftIsConstant) {
+ Value shiftConst;
+ if (shiftIsConstant)
+ shiftConst =
+ rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+
if (!a.getType().isInteger(32))
a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
+ auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
auto result = tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
+ rewriter, loc, rewriter.getI32Type(), a, b, shiftAmount,
rewriter.getStringAttr("SINGLE_ROUND"));
- if (elementTy.isInteger(32))
- return result;
-
- return arith::TruncIOp::create(rewriter, loc, elementTy, result);
+ return result;
}
int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -909,6 +910,20 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
return operand;
}
+static bool hasDynamicDimensions(ValueRange operands) {
+ for (auto operand : operands) {
+ auto rankedTensorType = cast_or_null<RankedTensorType>(operand.getType());
+ if (!rankedTensorType)
+ continue;
+ int64_t rank = rankedTensorType.getRank();
+ for (auto dim : llvm::seq<int64_t>(0, rank)) {
+ if (rankedTensorType.isDynamicDim(dim))
+ return true;
+ }
+ }
+ return false;
+}
+
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, ValueRange operands,
@@ -918,6 +933,9 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
if (operands.size() == 1)
return operands;
+ if (!hasDynamicDimensions(operands))
+ return operands;
+
// Broadcast dynamic dimensions operand by operand
return llvm::map_to_vector(operands, [&](Value operand) {
return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
@@ -990,8 +1008,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
static ValueRange getBroadcastableOperands(Operation *operation,
ValueRange operands) {
// Shift cannot broadcast
- if (isa<tosa::MulOp>(operation))
- return operands.take_front(2);
+ if (isa<tosa::MulOp>(operation)) {
+ DenseElementsAttr shiftElems;
+ // Shift cannot broadcast when it is constant
+ if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
+ return operands.take_front(2);
+ else
+ return operands.take_front(3);
+ }
// Input1_zp and output_zp cannot broadcast
if (isa<tosa::NegateOp>(operation))
return operands.take_front(1);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 69d8471df8032..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,11 +73,3 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
-
-// -----
-
-func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
- // expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
- %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
- return %0 : tensor<2x3xi32>
-}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index fb912e49ff920..aee0caa91043d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -2471,3 +2471,14 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {
return
}
+
+// -----
+
+// CHECK-LABEL: @mul_no_const_shift
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i8, %[[OUT:.*]]: i32):
+ // CHECK: tosa.apply_scale %[[ARG0]], %[[ARG1]], %[[ARG2]]
+ %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+ return %0 : tensor<2x3xi32>
+}
|
|
@llvm/pr-subscribers-mlir Author: None (ShivaChen) ChangesThe shift operand of tosa::MulOp could be non-constant when the dynamic extension enabled. Given that checkConstantOperandMul could check the shift operand according to the extension, we might able to relax the checking in TosaToLinalg. Commutative of MulOp might need to be removed to avoid shift operand been reordered with other operands when the shift operand is non-constant. Full diff: https://github.com/llvm/llvm-project/pull/155197.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..eed428da99192 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -983,7 +983,6 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
def Tosa_MulOp : Tosa_Op<"mul", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
- Commutative,
Pure]> {
let summary = "Multiplication operator.";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..a02d6c97aa5d8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::MulOp>(op)) {
auto shiftVal = cast<tosa::MulOp>(op).getShift();
DenseElementsAttr shiftElem;
- if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
- (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
- return nullptr;
- }
-
- int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ bool shiftIsConstant = true;
+ int32_t shift = 0;
+ if (matchPattern(shiftVal, m_Constant(&shiftElem)))
+ shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ else
+ shiftIsConstant = false;
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
@@ -147,23 +147,24 @@ static Value createLinalgBodyCalculationForElementwiseOp(
Value a = args[0];
Value b = args[1];
- if (shift > 0) {
- auto shiftConst =
- arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
+ if (shift > 0 || !shiftIsConstant) {
+ Value shiftConst;
+ if (shiftIsConstant)
+ shiftConst =
+ rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+
if (!a.getType().isInteger(32))
a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
+ auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
auto result = tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
+ rewriter, loc, rewriter.getI32Type(), a, b, shiftAmount,
rewriter.getStringAttr("SINGLE_ROUND"));
- if (elementTy.isInteger(32))
- return result;
-
- return arith::TruncIOp::create(rewriter, loc, elementTy, result);
+ return result;
}
int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -909,6 +910,20 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
return operand;
}
+static bool hasDynamicDimensions(ValueRange operands) {
+ for (auto operand : operands) {
+ auto rankedTensorType = cast_or_null<RankedTensorType>(operand.getType());
+ if (!rankedTensorType)
+ continue;
+ int64_t rank = rankedTensorType.getRank();
+ for (auto dim : llvm::seq<int64_t>(0, rank)) {
+ if (rankedTensorType.isDynamicDim(dim))
+ return true;
+ }
+ }
+ return false;
+}
+
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, ValueRange operands,
@@ -918,6 +933,9 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
if (operands.size() == 1)
return operands;
+ if (!hasDynamicDimensions(operands))
+ return operands;
+
// Broadcast dynamic dimensions operand by operand
return llvm::map_to_vector(operands, [&](Value operand) {
return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
@@ -990,8 +1008,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
static ValueRange getBroadcastableOperands(Operation *operation,
ValueRange operands) {
// Shift cannot broadcast
- if (isa<tosa::MulOp>(operation))
- return operands.take_front(2);
+ if (isa<tosa::MulOp>(operation)) {
+ DenseElementsAttr shiftElems;
+ // Shift cannot broadcast when it is constant
+ if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
+ return operands.take_front(2);
+ else
+ return operands.take_front(3);
+ }
// Input1_zp and output_zp cannot broadcast
if (isa<tosa::NegateOp>(operation))
return operands.take_front(1);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 69d8471df8032..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,11 +73,3 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
-
-// -----
-
-func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
- // expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
- %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
- return %0 : tensor<2x3xi32>
-}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index fb912e49ff920..aee0caa91043d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -2471,3 +2471,14 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {
return
}
+
+// -----
+
+// CHECK-LABEL: @mul_no_const_shift
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i8, %[[OUT:.*]]: i32):
+ // CHECK: tosa.apply_scale %[[ARG0]], %[[ARG1]], %[[ARG2]]
+ %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+ return %0 : tensor<2x3xi32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, thanks !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think we could just use something like
const auto tType = dyn_cast<RankedTensorType>(op.getType());
if (tType && !tType.hasStaticShape())
here to avoid declaring hasDynamicDimensions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks nicer, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if this might be better removed in a separate PR incase it breaks any assumptions (this way it can be cleanly reverted/bisected if there are any problems). e.g. I suspect when shift is constant, foldCommutative works correctly - it's possible there are patterns based on this assumption
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to be reasonable consideration and easier to catch issue in the future. I will add back Commutative. Thanks for catching this.
The shift operand of tosa::MulOp could be non-constant when the dynamic extension enabled. Given that checkConstantOperandMul could check the shift operand according to the extension, we might able to relax the checking in TosaToLinalg. Commutative of MulOp might need to be removed to avoid shift operand been reordered with other operands when the shift operand is non-constant.
72b273c to
76cbfe5
Compare
Thanks for review :-) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ShivaChen!
…oject into bugprone-method-hiding * 'bugprone-method-hiding' of github.com:t-a-james/llvm-project: (230 commits) [SimplifyCFG] Move token type check into canReplaceOperandWithVariable() [ADT] Fix signed integer overflow (llvm#155826) [Offload] Update LIBOMPTARGET_INFO text for `attach` map-type. (llvm#155509) [CMake][AIX] Enable CMP0182: Create shared library archives by default (llvm#155686) AMDGPU: Add tests for atomics with AGPR operands (llvm#155820) [AArch64] Split zero cycle zeoring per register class (llvm#154561) [gn build] Port fa883e1 [mlir][tosa] Allow shift operand of tosa::MulOp as non-constant (llvm#155197) [AArch64][NFC] Add MCInstrAnalysis unittests (llvm#155609) [Offload][OpenMP] Tests require libc on GPU for printf (llvm#155785) AMDGPU: Add missing verifier tests for load/store AGPR case (llvm#155815) [lldb-mcp] Fix building for Windows Revert "[lldb] Correct a usage after a rename was merged. (llvm#155720)" Revert "[lldb] NFC Moving mcp::Transport into its own file. (llvm#155711)" [lldb][test] Run ranges::ref_vew test only for libc++ (llvm#155813) [SCCP][FuncSpec] Poison unreachable constant global variable user (llvm#155753) [LoongArch] Lowering v32i8 vector mask generation to `VMSKLTZ` (llvm#149953) [flang][docs][NFC] Remove stray backtick (llvm#154974) [MLIR] Apply clang-tidy fixes for misc-use-internal-linkage in LinalgOps.cpp (NFC) [MLIR] Apply clang-tidy fixes for performance-move-const-arg in VariantValue.cpp (NFC) ...
The shift operand of tosa::MulOp could be non-constant when the dynamic extension enabled. Given that checkConstantOperandMul could check the shift operand according to the extension, we might able to relax the checking in TosaToLinalg.
Relative discussion: https://discourse.llvm.org/t/tosa-ext-dynamic-clearification-needed/87478?u=r2333333.