-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][affine] fix the issue of ceildiv-mul-ceildiv form expression n… #111254
Conversation
…ot satisfying commutative Fixes llvm#107508
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-affine Author: long.chen (lipracer) Changes…ot satisfying commutative Fixes #107508 Full diff: https://github.com/llvm/llvm-project/pull/111254.diff 2 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0b078966aeb85b..f947b8c3d54c6a 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -349,6 +349,8 @@ unsigned AffineDimExpr::getPosition() const {
return static_cast<ImplType *>(expr)->position;
}
+namespace {
+
/// Returns true if the expression is divisible by the given symbol with
/// position `symbolPos`. The argument `opKind` specifies here what kind of
/// division or mod operation called this division. It helps in implementing the
@@ -356,12 +358,17 @@ unsigned AffineDimExpr::getPosition() const {
///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
/// operation, then the commutative property can be used otherwise, the floordiv
/// operation is not divisible. The same argument holds for ceildiv operation.
-static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
- AffineExprKind opKind) {
+bool isDivisibleBySymbolImpl(AffineExpr expr, unsigned symbolPos,
+ AffineExprKind opKind,
+ SmallVectorImpl<AffineExpr> &visitedExprs,
+ size_t depth = 0) {
// The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
+ if (visitedExprs.size() > depth)
+ visitedExprs.resize(depth);
+ visitedExprs.emplace_back(expr);
switch (expr.getKind()) {
case AffineExprKind::Constant:
return cast<AffineConstantExpr>(expr).getValue() == 0;
@@ -372,8 +379,10 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+ return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
+ visitedExprs, depth + 1) &&
+ isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
+ visitedExprs, depth + 1);
}
// Checks divisibility by the given symbol for both operands. Consider the
// expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
@@ -382,16 +391,20 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
- AffineExprKind::Mod) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
- AffineExprKind::Mod);
+ return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
+ AffineExprKind::Mod, visitedExprs,
+ depth + 1) &&
+ isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos,
+ AffineExprKind::Mod, visitedExprs,
+ depth + 1);
}
// Checks if any of the operand divisible by the given symbol.
case AffineExprKind::Mul: {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
+ return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos, opKind,
+ visitedExprs, depth + 1) ||
+ isDivisibleBySymbolImpl(binaryExpr.getRHS(), symbolPos, opKind,
+ visitedExprs, depth + 1);
}
// Floordiv and ceildiv are divisible by the given symbol when the first
// operand is divisible, and the affine expression kind of the argument expr
@@ -406,12 +419,25 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind())
return false;
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
+ if (llvm::any_of(visitedExprs, [](auto expr) {
+ return expr.getKind() == AffineExprKind::Mul;
+ }))
+ return false;
+ return isDivisibleBySymbolImpl(binaryExpr.getLHS(), symbolPos,
+ expr.getKind(), visitedExprs, depth + 1);
}
}
llvm_unreachable("Unknown AffineExpr");
}
+bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
+ AffineExprKind opKind) {
+ SmallVector<AffineExpr> visitedExprs;
+ return isDivisibleBySymbolImpl(expr, symbolPos, opKind, visitedExprs);
+}
+
+} // namespace
+
/// Divides the given expression by the given symbol at position `symbolPos`. It
/// considers the divisibility condition is checked before calling itself. A
/// null expression is returned whenever the divisibility condition fails.
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
}
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+ // CHECK: %[[CST:.*]] = arith.constant 43
+ return %a : index
+}
+
+// Tests the do not simplification of a semi-affine expression with a nested ceildiv-mul-ceildiv operation.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
- // CHECK: %[[CST:.*]] = arith.constant 47
+ // CHECK-NOT: arith.constant
+ return %a : index
+}
+
+// Tests the do not simplification of a semi-affine expression with a nested floordiv_mul_floordiv operation
+// CHECK-LABEL: func @semiaffine_composite_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+ // CHECK-NOT: arith.constant
return %a : index
}
|
we can simple let's prove the similarly, it can be proven that: because |
…ot satisfying commutative
Fixes #107508