Collaborator

@ShivaChen ShivaChen commented Aug 25, 2025

The shift operand of tosa::MulOp can be non-constant when the dynamic extension is enabled. Given that checkConstantOperandMul can check the shift operand according to the extension, we might be able to relax the check in TosaToLinalg.

Related discussion: https://discourse.llvm.org/t/tosa-ext-dynamic-clearification-needed/87478?u=r2333333.

Member

llvmbot commented Aug 25, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: None (ShivaChen)

Changes

The shift operand of tosa::MulOp can be non-constant when the dynamic extension is enabled. Given that checkConstantOperandMul can check the shift operand according to the extension, we might be able to relax the check in TosaToLinalg.

The Commutative trait of MulOp might need to be removed so that the shift operand is not reordered with the other operands when it is non-constant.


Full diff: https://github.com/llvm/llvm-project/pull/155197.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+40-16)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (-8)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+11)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..eed428da99192 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -983,7 +983,6 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
 def Tosa_MulOp : Tosa_Op<"mul", [
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>,
-    Commutative,
     Pure]> {
   let summary = "Multiplication operator.";
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..a02d6c97aa5d8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::MulOp>(op)) {
     auto shiftVal = cast<tosa::MulOp>(op).getShift();
     DenseElementsAttr shiftElem;
-    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
-      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
-      return nullptr;
-    }
-
-    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+    bool shiftIsConstant = true;
+    int32_t shift = 0;
+    if (matchPattern(shiftVal, m_Constant(&shiftElem)))
+      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+    else
+      shiftIsConstant = false;
 
     if (isa<FloatType>(elementTy)) {
       if (shift != 0) {
@@ -147,23 +147,24 @@ static Value createLinalgBodyCalculationForElementwiseOp(
       Value a = args[0];
       Value b = args[1];
 
-      if (shift > 0) {
-        auto shiftConst =
-            arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
+      if (shift > 0 || !shiftIsConstant) {
+        Value shiftConst;
+        if (shiftIsConstant)
+          shiftConst =
+              rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+
         if (!a.getType().isInteger(32))
           a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
 
         if (!b.getType().isInteger(32))
           b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
 
+        auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
         auto result = tosa::ApplyScaleOp::create(
-            rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
+            rewriter, loc, rewriter.getI32Type(), a, b, shiftAmount,
             rewriter.getStringAttr("SINGLE_ROUND"));
 
-        if (elementTy.isInteger(32))
-          return result;
-
-        return arith::TruncIOp::create(rewriter, loc, elementTy, result);
+        return result;
       }
 
       int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -909,6 +910,20 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
   return operand;
 }
 
+static bool hasDynamicDimensions(ValueRange operands) {
+  for (auto operand : operands) {
+    auto rankedTensorType = cast_or_null<RankedTensorType>(operand.getType());
+    if (!rankedTensorType)
+      continue;
+    int64_t rank = rankedTensorType.getRank();
+    for (auto dim : llvm::seq<int64_t>(0, rank)) {
+      if (rankedTensorType.isDynamicDim(dim))
+        return true;
+    }
+  }
+  return false;
+}
+
 static SmallVector<Value>
 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                            IndexPool &indexPool, ValueRange operands,
@@ -918,6 +933,9 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
   if (operands.size() == 1)
     return operands;
 
+  if (!hasDynamicDimensions(operands))
+    return operands;
+
   // Broadcast dynamic dimensions operand by operand
   return llvm::map_to_vector(operands, [&](Value operand) {
     return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
@@ -990,8 +1008,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
 static ValueRange getBroadcastableOperands(Operation *operation,
                                            ValueRange operands) {
   // Shift cannot broadcast
-  if (isa<tosa::MulOp>(operation))
-    return operands.take_front(2);
+  if (isa<tosa::MulOp>(operation)) {
+    DenseElementsAttr shiftElems;
+    // Shift cannot broadcast when it is constant
+    if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
+      return operands.take_front(2);
+    else
+      return operands.take_front(3);
+  }
   // Input1_zp and output_zp cannot broadcast
   if (isa<tosa::NegateOp>(operation))
     return operands.take_front(1);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 69d8471df8032..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,11 +73,3 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
   %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
-
-// -----
-
-func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
-  // expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
-  %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
-  return %0 : tensor<2x3xi32>
-}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index fb912e49ff920..aee0caa91043d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -2471,3 +2471,14 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @mul_no_const_shift
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i8, %[[OUT:.*]]: i32):
+  // CHECK: tosa.apply_scale %[[ARG0]], %[[ARG1]], %[[ARG2]]
+  %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+  return %0 : tensor<2x3xi32>
+}
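
For readers following the lowering change, here is a minimal scalar sketch of what the patched path selects and computes per element. It is a standalone model, not code from the PR: `selectShift` and `mulWithShift` are hypothetical names, and the rounding rule (add half, then arithmetic shift right, i.e. single rounding) is an assumption about the intended i32 semantics; the real lowering of `tosa.apply_scale` may differ in details.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper mirroring the selection in the patch: use the
// compile-time constant shift when one was matched, otherwise fall back
// to the runtime value (args[2] in the linalg body).
int8_t selectShift(std::optional<int8_t> constantShift, int8_t runtimeShift) {
  return constantShift.has_value() ? *constantShift : runtimeShift;
}

// Hypothetical scalar model of an i32 multiply followed by a rounded
// right shift. Assumption: for shift > 0 the 64-bit product is rounded
// by adding 1 << (shift - 1) before an arithmetic shift right; a shift
// of 0 is a plain multiply. The result range is not checked here.
int32_t mulWithShift(int32_t a, int32_t b, int8_t shift) {
  int64_t product = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  if (shift > 0) {
    int64_t round = int64_t{1} << (shift - 1);
    // Right shift of a negative value is arithmetic in C++20 and on
    // mainstream compilers before that.
    product = (product + round) >> shift;
  }
  return static_cast<int32_t>(product);
}
```

With this model, `mulWithShift(7, 9, 2)` rounds 63 / 4 to 16, and the same function serves both cases because only the source of the shift value changes between the constant and non-constant paths.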

@ShivaChen ShivaChen requested review from lhutton1 and sjarus August 27, 2025 01:06
Contributor

@sjarus sjarus left a comment

Looks good to me, thanks !

Contributor

nit: I think we could just use something like

const auto tType = dyn_cast<RankedTensorType>(op.getType());
if (tType && !tType.hasStaticShape())

here to avoid declaring hasDynamicDimensions
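
To make the suggestion concrete, here is a small standalone model of the two checks. It does not use MLIR: `Shape`, `kDynamic`, and both functions are hypothetical stand-ins, with a sentinel value marking a dimension that is unknown at compile time.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for a ranked tensor type: a list of dimension
// sizes, where kDynamic marks a dimension unknown at compile time.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
using Shape = std::vector<int64_t>;

// Model of the single-type query the reviewer proposes: true when no
// dimension is dynamic.
bool hasStaticShape(const Shape &shape) {
  for (int64_t dim : shape)
    if (dim == kDynamic)
      return false;
  return true;
}

// Model of the operand-scanning helper added in the diff: true when any
// operand has at least one dynamic dimension.
bool hasDynamicDimensions(const std::vector<Shape> &operands) {
  for (const Shape &shape : operands)
    if (!hasStaticShape(shape))
      return true;
  return false;
}
```

Note that the reviewer's snippet queries a single type (`op.getType()`), whereas the helper in the diff scans every operand; the sketch keeps both so the difference in scope stays visible.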

Collaborator Author

It looks nicer, thanks.

Contributor

@lhutton1 lhutton1 Aug 27, 2025

I wonder if the Commutative trait might be better removed in a separate PR in case it breaks any assumptions (that way it can be cleanly reverted or bisected if there are any problems). For example, I suspect that when the shift is constant, foldCommutative works correctly, and it's possible there are patterns that rely on this.
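
The concern can be illustrated with a simplified standalone model. It assumes, without claiming this is the exact upstream implementation, that the commutative folder moves constant operands to the end while keeping the relative order within each group; `Operand` and `reorderCommutative` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical operand record: a name plus whether it is a constant.
struct Operand {
  std::string name;
  bool isConstant;
};

// Simplified model (assumption): non-constant operands are kept first and
// constant operands are moved to the end, preserving relative order.
std::vector<std::string> reorderCommutative(std::vector<Operand> operands) {
  std::stable_partition(operands.begin(), operands.end(),
                        [](const Operand &o) { return !o.isConstant; });
  std::vector<std::string> names;
  for (const Operand &o : operands)
    names.push_back(o.name);
  return names;
}
```

Under this model a constant shift always stays in the last position, since it is the last constant. A non-constant shift combined with a constant second input, however, ends up ahead of that input, which is the reordering the PR description worries about.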

Collaborator Author

That seems a reasonable consideration, and it would make issues easier to catch in the future. I will add Commutative back. Thanks for catching this.

The shift operand of tosa::MulOp can be non-constant when
the dynamic extension is enabled. Given that checkConstantOperandMul
can check the shift operand according to the extension, we
might be able to relax the check in TosaToLinalg.

The Commutative trait of MulOp might need to be removed so that
the shift operand is not reordered with the other operands when
it is non-constant.
@ShivaChen
Collaborator Author

> Looks good to me, thanks !

Thanks for the review :-)

Contributor

@lhutton1 lhutton1 left a comment

Thanks @ShivaChen!

@lhutton1 lhutton1 merged commit 6926a6b into llvm:main Aug 28, 2025
9 checks passed