Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression n… #111254

Merged
merged 2 commits into from
Oct 12, 2024

Conversation

lipracer
Copy link
Member

@lipracer lipracer commented Oct 5, 2024

…ot satisfying commutative

Fixes #107508

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:affine mlir labels Oct 5, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 5, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: long.chen (lipracer)

Changes

…ot satisfying commutative

Fixes #107508


Full diff: https://github.com/llvm/llvm-project/pull/111254.diff

2 Files Affected:

  • (modified) mlir/lib/IR/AffineExpr.cpp (+37-11)
  • (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (+19-3)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..f947b8c3d54c6a 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -349,6 +349,8 @@ unsigned AffineDimExpr::getPosition() const {
   return static_cast<ImplType *>(expr)->position;
 }
 
+namespace {
+
 /// Returns true if the expression is divisible by the given symbol with
 /// position `symbolPos`. The argument `opKind` specifies here what kind of
 /// division or mod operation called this division. It helps in implementing the
@@ -356,12 +358,17 @@ unsigned AffineDimExpr::getPosition() const {
 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
 /// operation, then the commutative property can be used otherwise, the floordiv
 /// operation is not divisible. The same argument holds for ceildiv operation.
-static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
-                                AffineExprKind opKind) {
+bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
+                             AffineExprKind opKind,
+                             SmallVectorImpl<AffineExpr> &visitedExprs,
+                             size_t depth = 0) {
   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
+  if (visitedExprs.size() > depth)
+    visitedExprs.resize(depth);
+  visitedExprs.emplace_back(expr);
   switch (expr.getKind()) {
   case AffineExprKind::Constant:
     return cast<AffineConstantExpr>(expr).getValue() == 0;
@@ -372,8 +379,10 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   // Checks divisibility by the given symbol for both operands.
   case AffineExprKind::Add: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
+                                   visitedExprs, depth + 1) &&
+           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
+                                   visitedExprs, depth + 1);
   }
   // Checks divisibility by the given symbol for both operands. Consider the
   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -382,16 +391,20 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   // `AffineExprKind::Mod` for this reason.
   case AffineExprKind::Mod: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
-                               AffineExprKind::Mod) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
-                               AffineExprKind::Mod);
+    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
+                                   AffineExprKind::Mod, visitedExprs,
+                                   depth + 1) &&
+           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos,
+                                   AffineExprKind::Mod, visitedExprs,
+                                   depth + 1);
   }
   // Checks if any of the operand divisible by the given symbol.
   case AffineExprKind::Mul: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
+                                   visitedExprs, depth + 1) ||
+           isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
+                                   visitedExprs, depth + 1);
   }
   // Floordiv and ceildiv are divisible by the given symbol when the first
   // operand is divisible, and the affine expression kind of the argument expr
@@ -406,12 +419,25 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
     if (opKind != expr.getKind())
       return false;
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
+    if (llvm::any_of(visitedExprs, [](auto expr) {
+          return expr.getKind() == AffineExprKind::Mul;
+        }))
+      return false;
+    return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
+                                   expr.getKind(), visitedExprs, depth + 1);
   }
   }
   llvm_unreachable("Unknown AffineExpr");
 }
 
+bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
+                         AffineExprKind opKind) {
+  SmallVector<AffineExpr> visitedExprs;
+  return isDivisibleBySymbolImpl(expr, symbolPos, opKind, visitedExprs);
+}
+
+} // namespace
+
 /// Divides the given expression by the given symbol at position `symbolPos`. It
 /// considers the divisibility condition is checked before calling itself. A
 /// null expression is returned whenever the divisibility condition fails.
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
 }
 
 // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = arith.constant 43
+  return %a : index
+}
+
+// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
   %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
-  // CHECK:       %[[CST:.*]] = arith.constant 47
+  // CHECK-NOT:       arith.constant
+  return %a : index
+}
+
+// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation
+// CHECK-LABEL: func @semiaffine_composite_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+  // CHECK-NOT:       arith.constant
   return %a : index
 }
 

mlir/lib/IR/AffineExpr.cpp Outdated Show resolved Hide resolved
mlir/lib/IR/AffineExpr.cpp Outdated Show resolved Hide resolved
@lipracer
Copy link
Member Author

lipracer commented Oct 9, 2024

we can simple (n * s) ceildiv a ceildiv s to n ceildiv a
because (n * s) ceildiv a ceildiv b <=> (n * s) ceildiv s ceildiv a
<=> n ceildiv a

let's prove the s floordiv a floor b <=> s floordiv b floor a
let s = ka +m (m < a) so s floordiv a <=> s / a - m / a

similarly, it can be proven that:
s floordiv a floordiv b <=> s / (a * b) - m / (a * b) - n / (b) constrain (n < b)
<=> s / (a * b) - (m + a*n) / (a*b)

because a* b - (m + a*n) <=> a*b - a*n - m > a - m > 0
so s floordiv a floordiv b <=> [s / (a*b)] <=> s floordiv b floordiv a
but if s floordiv b mutiply a factor above didn't always hold true.

@lipracer lipracer added the awaiting-review Has pending Phabricator review label Oct 10, 2024
@lipracer lipracer merged commit 51a2f50 into llvm:main Oct 12, 2024
10 checks passed
@lipracer lipracer deleted the fix-107508 branch October 12, 2024 15:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
awaiting-review Has pending Phabricator review mlir:affine mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Semantic inconsistency in ceildiv optimization
3 participants