Skip to content

[mlir][affine] Add an integer range interface to affine.apply#174277

Merged
Hardcode84 merged 3 commits intollvm:mainfrom
Hardcode84:affine-int-range
Jan 7, 2026
Merged

[mlir][affine] Add an integer range interface to affine.apply#174277
Hardcode84 merged 3 commits intollvm:mainfrom
Hardcode84:affine-int-range

Conversation

@Hardcode84
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jan 3, 2026

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/174277.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+3-1)
  • (modified) mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h (+7)
  • (modified) mlir/lib/Dialect/Affine/IR/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp (+38)
  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+99)
  • (added) mlir/test/Dialect/Affine/int-range-interface.mlir (+113)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 409bd05292e0d..bd14f6ff4c5aa 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -16,6 +16,7 @@
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -35,7 +36,8 @@ class Affine_Op<string mnemonic, list<Trait> traits = []> :
 def ImplicitAffineTerminator
     : SingleBlockImplicitTerminator<"AffineYieldOp">;
 
-def AffineApplyOp : Affine_Op<"apply", [Pure]> {
+def AffineApplyOp : Affine_Op<"apply",
+    [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "affine apply operation";
   let description = [{
     The `affine.apply` operation applies an [affine mapping](#affine-maps)
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index e46358ccfc46f..e369c80a26ea9 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -20,6 +20,7 @@
 #include <optional>
 
 namespace mlir {
+class AffineExpr;
 class ShapedDimOpInterface;
 
 namespace intrange {
@@ -151,6 +152,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
 ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
                                             const IntegerValueRange &maybeDim);
 
+/// Infer the integer range for an affine expression given ranges for its
+/// dimensions and symbols.
+ConstantIntRanges inferAffineExpr(AffineExpr expr,
+                                  ArrayRef<ConstantIntRanges> dimRanges,
+                                  ArrayRef<ConstantIntRanges> symbolRanges);
+
 } // namespace intrange
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 7f7a01be891e0..566bc060e5d38 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRAffineDialect
   AffineMemoryOpInterfaces.cpp
   AffineOps.cpp
   AffineValueMap.cpp
+  InferIntRangeInterfaceImpls.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -15,6 +16,7 @@ add_mlir_dialect_library(MLIRAffineDialect
   MLIRArithDialect
   MLIRDialectUtils
   MLIRIR
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp
new file mode 100644
index 0000000000000..1bf6829a04b6e
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp
@@ -0,0 +1,38 @@
+//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for affine --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+
+using namespace mlir;
+using namespace mlir::affine;
+using namespace mlir::intrange;
+
+//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
+void AffineApplyOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  AffineMap map = getAffineMap();
+
+  // Split operand ranges into dimensions and symbols.
+  unsigned numDims = map.getNumDims();
+  ArrayRef<ConstantIntRanges> dimRanges = argRanges.take_front(numDims);
+  ArrayRef<ConstantIntRanges> symbolRanges = argRanges.drop_front(numDims);
+
+  // Affine maps should have exactly one result for affine.apply.
+  assert(map.getNumResults() == 1 && "affine.apply must have single result");
+
+  // Infer the range for the affine expression.
+  ConstantIntRanges resultRange =
+      inferAffineExpr(map.getResult(0), dimRanges, symbolRanges);
+
+  setResultRange(getResult(), resultRange);
+}
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 0f28cbc751c1c..49977a0a5fc27 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 
@@ -768,3 +769,101 @@ mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
   }
   return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
 }
+
+//===----------------------------------------------------------------------===//
+// Affine expression inference
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferAffineExpr(AffineExpr expr,
+                                ArrayRef<ConstantIntRanges> dimRanges,
+                                ArrayRef<ConstantIntRanges> symbolRanges) {
+  switch (expr.getKind()) {
+  case AffineExprKind::Constant: {
+    auto constExpr = cast<AffineConstantExpr>(expr);
+    APInt value(indexMaxWidth, constExpr.getValue(), /*isSigned=*/true);
+    return ConstantIntRanges::constant(value);
+  }
+  case AffineExprKind::DimId: {
+    auto dimExpr = cast<AffineDimExpr>(expr);
+    unsigned pos = dimExpr.getPosition();
+    assert(pos < dimRanges.size() && "Dimension index out of bounds");
+    return dimRanges[pos];
+  }
+  case AffineExprKind::SymbolId: {
+    auto symbolExpr = cast<AffineSymbolExpr>(expr);
+    unsigned pos = symbolExpr.getPosition();
+    assert(pos < symbolRanges.size() && "Symbol index out of bounds");
+    return symbolRanges[pos];
+  }
+  case AffineExprKind::Add: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    return inferAdd({lhs, rhs}, OverflowFlags::Nsw);
+  }
+  case AffineExprKind::Mul: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    return inferMul({lhs, rhs}, OverflowFlags::Nsw);
+  }
+  case AffineExprKind::Mod: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    // Affine mod is Euclidean modulo: result is always in [0, rhs_max-1].
+    // This assumes RHS is positive (enforced by affine expr semantics).
+    unsigned width = rhs.smin().getBitWidth();
+    APInt zero = APInt::getZero(width);
+    APInt maxRhs = rhs.umax();
+    if (maxRhs.isZero())
+      return ConstantIntRanges::maxRange(width);
+    APInt upper = maxRhs - 1;
+    return ConstantIntRanges::fromUnsigned(zero, upper);
+  }
+  case AffineExprKind::FloorDiv: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    // Affine floordiv requires strictly positive divisor (> 0).
+    // Clamp divisor lower bound to 1 for tighter range inference.
+    unsigned width = rhs.smin().getBitWidth();
+    APInt one(width, 1);
+    APInt clampedUMin = rhs.umin().ult(one) ? one : rhs.umin();
+    APInt clampedSMin = rhs.smin().slt(one) ? one : rhs.smin();
+    ConstantIntRanges clampedRhs =
+        ConstantIntRanges::fromUnsigned(clampedUMin, rhs.umax())
+            .intersection(
+                ConstantIntRanges::fromSigned(clampedSMin, rhs.smax()));
+    return inferFloorDivS({lhs, clampedRhs});
+  }
+  case AffineExprKind::CeilDiv: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    // Affine ceildiv requires strictly positive divisor (> 0).
+    // Clamp divisor lower bound to 1 for tighter range inference.
+    unsigned width = rhs.smin().getBitWidth();
+    APInt one(width, 1);
+    APInt clampedUMin = rhs.umin().ult(one) ? one : rhs.umin();
+    APInt clampedSMin = rhs.smin().slt(one) ? one : rhs.smin();
+    ConstantIntRanges clampedRhs =
+        ConstantIntRanges::fromUnsigned(clampedUMin, rhs.umax())
+            .intersection(
+                ConstantIntRanges::fromSigned(clampedSMin, rhs.smax()));
+    return inferCeilDivS({lhs, clampedRhs});
+  }
+  }
+  llvm_unreachable("unknown affine expression kind");
+}
diff --git a/mlir/test/Dialect/Affine/int-range-interface.mlir b/mlir/test/Dialect/Affine/int-range-interface.mlir
new file mode 100644
index 0000000000000..85a83f318aa01
--- /dev/null
+++ b/mlir/test/Dialect/Affine/int-range-interface.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s
+
+// CHECK-LABEL: func @affine_apply_constant
+// CHECK: test.reflect_bounds {smax = 42 : index, smin = 42 : index, umax = 42 : index, umin = 42 : index}
+func.func @affine_apply_constant() -> index {
+  %0 = affine.apply affine_map<() -> (42)>()
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_add
+// CHECK: test.reflect_bounds {smax = 15 : index, smin = 6 : index, umax = 15 : index, umin = 6 : index}
+func.func @affine_apply_add() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %d1 = test.with_bounds { umin = 4 : index, umax = 10 : index,
+                           smin = 4 : index, smax = 10 : index } : index
+  %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%d0, %d1)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_mul
+// CHECK: test.reflect_bounds {smax = 30 : index, smin = 12 : index, umax = 30 : index, umin = 12 : index}
+func.func @affine_apply_mul() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %s0 = test.with_bounds { umin = 6 : index, umax = 6 : index,
+                           smin = 6 : index, smax = 6 : index } : index
+  %0 = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%d0)[%s0]
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_floordiv
+// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
+func.func @affine_apply_floordiv() -> index {
+  %d0 = test.with_bounds { umin = 5 : index, umax = 10 : index,
+                           smin = 5 : index, smax = 10 : index } : index
+  %0 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%d0)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_ceildiv
+// CHECK: test.reflect_bounds {smax = 3 : index, smin = 2 : index, umax = 3 : index, umin = 2 : index}
+func.func @affine_apply_ceildiv() -> index {
+  %d0 = test.with_bounds { umin = 5 : index, umax = 10 : index,
+                           smin = 5 : index, smax = 10 : index } : index
+  %0 = affine.apply affine_map<(d0) -> (d0 ceildiv 4)>(%d0)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_mod
+// CHECK: test.reflect_bounds {smax = 3 : index, smin = 0 : index, umax = 3 : index, umin = 0 : index}
+func.func @affine_apply_mod() -> index {
+  %d0 = test.with_bounds { umin = 5 : index, umax = 27 : index,
+                           smin = 5 : index, smax = 27 : index } : index
+  %0 = affine.apply affine_map<(d0) -> (d0 mod 4)>(%d0)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_complex
+// CHECK: test.reflect_bounds {smax = 13 : index, smin = 5 : index, umax = 13 : index, umin = 5 : index}
+func.func @affine_apply_complex() -> index {
+  %d0 = test.with_bounds { umin = 10 : index, umax = 20 : index,
+                           smin = 10 : index, smax = 20 : index } : index
+  %d1 = test.with_bounds { umin = 3 : index, umax = 7 : index,
+                           smin = 3 : index, smax = 7 : index } : index
+  // (d0 floordiv 2) + (d1 mod 4) = [5, 10] + [0, 3] = [5, 13]
+  %0 = affine.apply affine_map<(d0, d1) -> (d0 floordiv 2 + d1 mod 4)>(%d0, %d1)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_with_symbols
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 9 : index, umax = 24 : index, umin = 9 : index}
+func.func @affine_apply_with_symbols() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %s0 = test.with_bounds { umin = 3 : index, umax = 4 : index,
+                           smin = 3 : index, smax = 4 : index } : index
+  // d0 * s0 + s0 = s0 * (d0 + 1) = [3, 4] * [3, 6] = [9, 24]
+  %0 = affine.apply affine_map<(d0)[s0] -> (d0 * s0 + s0)>(%d0)[%s0]
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_sub
+// CHECK: test.reflect_bounds {smax = 1 : index, smin = -8 : index
+func.func @affine_apply_sub() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %d1 = test.with_bounds { umin = 4 : index, umax = 10 : index,
+                           smin = 4 : index, smax = 10 : index } : index
+  // d0 - d1 = [2, 5] - [4, 10] = [2-10, 5-4] = [-8, 1]
+  %0 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%d0, %d1)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_mul_constant
+// CHECK: test.reflect_bounds {smax = 20 : index, smin = 8 : index, umax = 20 : index, umin = 8 : index}
+func.func @affine_apply_mul_constant() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  // d0 * 4 = [2, 5] * 4 = [8, 20]
+  %0 = affine.apply affine_map<(d0) -> (d0 * 4)>(%d0)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}

@llvmbot
Copy link
Member

llvmbot commented Jan 3, 2026

@llvm/pr-subscribers-mlir-affine

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/174277.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+3-1)
  • (modified) mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h (+7)
  • (modified) mlir/lib/Dialect/Affine/IR/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp (+38)
  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+99)
  • (added) mlir/test/Dialect/Affine/int-range-interface.mlir (+113)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 409bd05292e0d..bd14f6ff4c5aa 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -16,6 +16,7 @@
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -35,7 +36,8 @@ class Affine_Op<string mnemonic, list<Trait> traits = []> :
 def ImplicitAffineTerminator
     : SingleBlockImplicitTerminator<"AffineYieldOp">;
 
-def AffineApplyOp : Affine_Op<"apply", [Pure]> {
+def AffineApplyOp : Affine_Op<"apply",
+    [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "affine apply operation";
   let description = [{
     The `affine.apply` operation applies an [affine mapping](#affine-maps)
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index e46358ccfc46f..e369c80a26ea9 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -20,6 +20,7 @@
 #include <optional>
 
 namespace mlir {
+class AffineExpr;
 class ShapedDimOpInterface;
 
 namespace intrange {
@@ -151,6 +152,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
 ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
                                             const IntegerValueRange &maybeDim);
 
+/// Infer the integer range for an affine expression given ranges for its
+/// dimensions and symbols.
+ConstantIntRanges inferAffineExpr(AffineExpr expr,
+                                  ArrayRef<ConstantIntRanges> dimRanges,
+                                  ArrayRef<ConstantIntRanges> symbolRanges);
+
 } // namespace intrange
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 7f7a01be891e0..566bc060e5d38 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRAffineDialect
   AffineMemoryOpInterfaces.cpp
   AffineOps.cpp
   AffineValueMap.cpp
+  InferIntRangeInterfaceImpls.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -15,6 +16,7 @@ add_mlir_dialect_library(MLIRAffineDialect
   MLIRArithDialect
   MLIRDialectUtils
   MLIRIR
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp
new file mode 100644
index 0000000000000..1bf6829a04b6e
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp
@@ -0,0 +1,38 @@
+//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for affine --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+
+using namespace mlir;
+using namespace mlir::affine;
+using namespace mlir::intrange;
+
+//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
+void AffineApplyOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  AffineMap map = getAffineMap();
+
+  // Split operand ranges into dimensions and symbols.
+  unsigned numDims = map.getNumDims();
+  ArrayRef<ConstantIntRanges> dimRanges = argRanges.take_front(numDims);
+  ArrayRef<ConstantIntRanges> symbolRanges = argRanges.drop_front(numDims);
+
+  // Affine maps should have exactly one result for affine.apply.
+  assert(map.getNumResults() == 1 && "affine.apply must have single result");
+
+  // Infer the range for the affine expression.
+  ConstantIntRanges resultRange =
+      inferAffineExpr(map.getResult(0), dimRanges, symbolRanges);
+
+  setResultRange(getResult(), resultRange);
+}
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 0f28cbc751c1c..49977a0a5fc27 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 
@@ -768,3 +769,101 @@ mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
   }
   return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
 }
+
+//===----------------------------------------------------------------------===//
+// Affine expression inference
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferAffineExpr(AffineExpr expr,
+                                ArrayRef<ConstantIntRanges> dimRanges,
+                                ArrayRef<ConstantIntRanges> symbolRanges) {
+  switch (expr.getKind()) {
+  case AffineExprKind::Constant: {
+    auto constExpr = cast<AffineConstantExpr>(expr);
+    APInt value(indexMaxWidth, constExpr.getValue(), /*isSigned=*/true);
+    return ConstantIntRanges::constant(value);
+  }
+  case AffineExprKind::DimId: {
+    auto dimExpr = cast<AffineDimExpr>(expr);
+    unsigned pos = dimExpr.getPosition();
+    assert(pos < dimRanges.size() && "Dimension index out of bounds");
+    return dimRanges[pos];
+  }
+  case AffineExprKind::SymbolId: {
+    auto symbolExpr = cast<AffineSymbolExpr>(expr);
+    unsigned pos = symbolExpr.getPosition();
+    assert(pos < symbolRanges.size() && "Symbol index out of bounds");
+    return symbolRanges[pos];
+  }
+  case AffineExprKind::Add: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    return inferAdd({lhs, rhs}, OverflowFlags::Nsw);
+  }
+  case AffineExprKind::Mul: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    return inferMul({lhs, rhs}, OverflowFlags::Nsw);
+  }
+  case AffineExprKind::Mod: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    // Affine mod is Euclidean modulo: result is always in [0, rhs_max-1].
+    // This assumes RHS is positive (enforced by affine expr semantics).
+    unsigned width = rhs.smin().getBitWidth();
+    APInt zero = APInt::getZero(width);
+    APInt maxRhs = rhs.umax();
+    if (maxRhs.isZero())
+      return ConstantIntRanges::maxRange(width);
+    APInt upper = maxRhs - 1;
+    return ConstantIntRanges::fromUnsigned(zero, upper);
+  }
+  case AffineExprKind::FloorDiv: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    // Affine floordiv requires strictly positive divisor (> 0).
+    // Clamp divisor lower bound to 1 for tighter range inference.
+    unsigned width = rhs.smin().getBitWidth();
+    APInt one(width, 1);
+    APInt clampedUMin = rhs.umin().ult(one) ? one : rhs.umin();
+    APInt clampedSMin = rhs.smin().slt(one) ? one : rhs.smin();
+    ConstantIntRanges clampedRhs =
+        ConstantIntRanges::fromUnsigned(clampedUMin, rhs.umax())
+            .intersection(
+                ConstantIntRanges::fromSigned(clampedSMin, rhs.smax()));
+    return inferFloorDivS({lhs, clampedRhs});
+  }
+  case AffineExprKind::CeilDiv: {
+    auto binExpr = cast<AffineBinaryOpExpr>(expr);
+    ConstantIntRanges lhs =
+        inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
+    ConstantIntRanges rhs =
+        inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
+    // Affine ceildiv requires strictly positive divisor (> 0).
+    // Clamp divisor lower bound to 1 for tighter range inference.
+    unsigned width = rhs.smin().getBitWidth();
+    APInt one(width, 1);
+    APInt clampedUMin = rhs.umin().ult(one) ? one : rhs.umin();
+    APInt clampedSMin = rhs.smin().slt(one) ? one : rhs.smin();
+    ConstantIntRanges clampedRhs =
+        ConstantIntRanges::fromUnsigned(clampedUMin, rhs.umax())
+            .intersection(
+                ConstantIntRanges::fromSigned(clampedSMin, rhs.smax()));
+    return inferCeilDivS({lhs, clampedRhs});
+  }
+  }
+  llvm_unreachable("unknown affine expression kind");
+}
diff --git a/mlir/test/Dialect/Affine/int-range-interface.mlir b/mlir/test/Dialect/Affine/int-range-interface.mlir
new file mode 100644
index 0000000000000..85a83f318aa01
--- /dev/null
+++ b/mlir/test/Dialect/Affine/int-range-interface.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s
+
+// CHECK-LABEL: func @affine_apply_constant
+// CHECK: test.reflect_bounds {smax = 42 : index, smin = 42 : index, umax = 42 : index, umin = 42 : index}
+func.func @affine_apply_constant() -> index {
+  %0 = affine.apply affine_map<() -> (42)>()
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_add
+// CHECK: test.reflect_bounds {smax = 15 : index, smin = 6 : index, umax = 15 : index, umin = 6 : index}
+func.func @affine_apply_add() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %d1 = test.with_bounds { umin = 4 : index, umax = 10 : index,
+                           smin = 4 : index, smax = 10 : index } : index
+  %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%d0, %d1)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_mul
+// CHECK: test.reflect_bounds {smax = 30 : index, smin = 12 : index, umax = 30 : index, umin = 12 : index}
+func.func @affine_apply_mul() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %s0 = test.with_bounds { umin = 6 : index, umax = 6 : index,
+                           smin = 6 : index, smax = 6 : index } : index
+  %0 = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%d0)[%s0]
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_floordiv
+// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
+func.func @affine_apply_floordiv() -> index {
+  %d0 = test.with_bounds { umin = 5 : index, umax = 10 : index,
+                           smin = 5 : index, smax = 10 : index } : index
+  %0 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%d0)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_ceildiv
+// CHECK: test.reflect_bounds {smax = 3 : index, smin = 2 : index, umax = 3 : index, umin = 2 : index}
+func.func @affine_apply_ceildiv() -> index {
+  %d0 = test.with_bounds { umin = 5 : index, umax = 10 : index,
+                           smin = 5 : index, smax = 10 : index } : index
+  %0 = affine.apply affine_map<(d0) -> (d0 ceildiv 4)>(%d0)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_mod
+// CHECK: test.reflect_bounds {smax = 3 : index, smin = 0 : index, umax = 3 : index, umin = 0 : index}
+func.func @affine_apply_mod() -> index {
+  %d0 = test.with_bounds { umin = 5 : index, umax = 27 : index,
+                           smin = 5 : index, smax = 27 : index } : index
+  %0 = affine.apply affine_map<(d0) -> (d0 mod 4)>(%d0)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_complex
+// CHECK: test.reflect_bounds {smax = 13 : index, smin = 5 : index, umax = 13 : index, umin = 5 : index}
+func.func @affine_apply_complex() -> index {
+  %d0 = test.with_bounds { umin = 10 : index, umax = 20 : index,
+                           smin = 10 : index, smax = 20 : index } : index
+  %d1 = test.with_bounds { umin = 3 : index, umax = 7 : index,
+                           smin = 3 : index, smax = 7 : index } : index
+  // (d0 floordiv 2) + (d1 mod 4) = [5, 10] + [0, 3] = [5, 13]
+  %0 = affine.apply affine_map<(d0, d1) -> (d0 floordiv 2 + d1 mod 4)>(%d0, %d1)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_with_symbols
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 9 : index, umax = 24 : index, umin = 9 : index}
+func.func @affine_apply_with_symbols() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %s0 = test.with_bounds { umin = 3 : index, umax = 4 : index,
+                           smin = 3 : index, smax = 4 : index } : index
+  // d0 * s0 + s0 = s0 * (d0 + 1) = [3, 4] * [3, 6] = [9, 24]
+  %0 = affine.apply affine_map<(d0)[s0] -> (d0 * s0 + s0)>(%d0)[%s0]
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_sub
+// CHECK: test.reflect_bounds {smax = 1 : index, smin = -8 : index
+func.func @affine_apply_sub() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %d1 = test.with_bounds { umin = 4 : index, umax = 10 : index,
+                           smin = 4 : index, smax = 10 : index } : index
+  // d0 - d1 = [2, 5] - [4, 10] = [2-10, 5-4] = [-8, 1]
+  %0 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%d0, %d1)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @affine_apply_mul_constant
+// CHECK: test.reflect_bounds {smax = 20 : index, smin = 8 : index, umax = 20 : index, umin = 8 : index}
+func.func @affine_apply_mul_constant() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  // d0 * 4 = [2, 5] * 4 = [8, 20]
+  %0 = affine.apply affine_map<(d0) -> (d0 * 4)>(%d0)
+  %1 = test.reflect_bounds %0 : index
+  func.return %1 : index
+}

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, and we might want to get affine.linearize_index and affine.delinearize_index in future PRs

@krzysz00 krzysz00 self-requested a review January 6, 2026 17:14
@Hardcode84 Hardcode84 merged commit 4dc9a0e into llvm:main Jan 7, 2026
10 checks passed
@Hardcode84 Hardcode84 deleted the affine-int-range branch January 7, 2026 00:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants