Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand All @@ -35,7 +36,8 @@ class Affine_Op<string mnemonic, list<Trait> traits = []> :
def ImplicitAffineTerminator
: SingleBlockImplicitTerminator<"AffineYieldOp">;

def AffineApplyOp : Affine_Op<"apply", [Pure]> {
def AffineApplyOp : Affine_Op<"apply",
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "affine apply operation";
let description = [{
The `affine.apply` operation applies an [affine mapping](#affine-maps)
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <optional>

namespace mlir {
class AffineExpr;
class ShapedDimOpInterface;

namespace intrange {
Expand Down Expand Up @@ -151,6 +152,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
const IntegerValueRange &maybeDim);

/// Infer the integer range for an affine expression given ranges for its
/// dimensions and symbols.
ConstantIntRanges inferAffineExpr(AffineExpr expr,
ArrayRef<ConstantIntRanges> dimRanges,
ArrayRef<ConstantIntRanges> symbolRanges);

} // namespace intrange
} // namespace mlir

Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Affine/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRAffineDialect
AffineMemoryOpInterfaces.cpp
AffineOps.cpp
AffineValueMap.cpp
InferIntRangeInterfaceImpls.cpp
ValueBoundsOpInterfaceImpl.cpp

ADDITIONAL_HEADER_DIRS
Expand All @@ -15,6 +16,7 @@ add_mlir_dialect_library(MLIRAffineDialect
MLIRArithDialect
MLIRDialectUtils
MLIRIR
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
MLIRMemRefDialect
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/Affine/IR/InferIntRangeInterfaceImpls.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for affine --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"

using namespace mlir;
using namespace mlir::affine;
using namespace mlir::intrange;

//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//

void AffineApplyOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
AffineMap map = getAffineMap();

// Split operand ranges into dimensions and symbols.
unsigned numDims = map.getNumDims();
ArrayRef<ConstantIntRanges> dimRanges = argRanges.take_front(numDims);
ArrayRef<ConstantIntRanges> symbolRanges = argRanges.drop_front(numDims);

// Affine maps should have exactly one result for affine.apply.
assert(map.getNumResults() == 1 && "affine.apply must have single result");

// Infer the range for the affine expression.
ConstantIntRanges resultRange =
inferAffineExpr(map.getResult(0), dimRanges, symbolRanges);

setResultRange(getResult(), resultRange);
}
126 changes: 126 additions & 0 deletions mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"

Expand Down Expand Up @@ -768,3 +769,128 @@ mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
}
return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
}

//===----------------------------------------------------------------------===//
// Affine expression inference
//===----------------------------------------------------------------------===//

static ConstantIntRanges clampToPositive(const ConstantIntRanges &val) {
unsigned width = val.smin().getBitWidth();
APInt one(width, 1);
APInt clampedUMin = val.umin().ult(one) ? one : val.umin();
APInt clampedSMin = val.smin().slt(one) ? one : val.smin();
return ConstantIntRanges::fromUnsigned(clampedUMin, val.umax())
.intersection(ConstantIntRanges::fromSigned(clampedSMin, val.smax()));
}

ConstantIntRanges
mlir::intrange::inferAffineExpr(AffineExpr expr,
ArrayRef<ConstantIntRanges> dimRanges,
ArrayRef<ConstantIntRanges> symbolRanges) {
switch (expr.getKind()) {
case AffineExprKind::Constant: {
auto constExpr = cast<AffineConstantExpr>(expr);
APInt value(indexMaxWidth, constExpr.getValue(), /*isSigned=*/true);
return ConstantIntRanges::constant(value);
}
case AffineExprKind::DimId: {
auto dimExpr = cast<AffineDimExpr>(expr);
unsigned pos = dimExpr.getPosition();
assert(pos < dimRanges.size() && "Dimension index out of bounds");
return dimRanges[pos];
}
case AffineExprKind::SymbolId: {
auto symbolExpr = cast<AffineSymbolExpr>(expr);
unsigned pos = symbolExpr.getPosition();
assert(pos < symbolRanges.size() && "Symbol index out of bounds");
return symbolRanges[pos];
}
case AffineExprKind::Add: {
auto binExpr = cast<AffineBinaryOpExpr>(expr);
ConstantIntRanges lhs =
inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
ConstantIntRanges rhs =
inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
return inferAdd({lhs, rhs}, OverflowFlags::Nsw);
}
case AffineExprKind::Mul: {
auto binExpr = cast<AffineBinaryOpExpr>(expr);
ConstantIntRanges lhs =
inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
ConstantIntRanges rhs =
inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
return inferMul({lhs, rhs}, OverflowFlags::Nsw);
}
case AffineExprKind::Mod: {
auto binExpr = cast<AffineBinaryOpExpr>(expr);
ConstantIntRanges lhs =
inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
ConstantIntRanges rhs =
inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
// Affine mod is Euclidean modulo: result is always in [0, rhs-1].
// This assumes RHS is positive (enforced by affine expr semantics).
const APInt &lhsMin = lhs.smin();
const APInt &lhsMax = lhs.smax();
const APInt &rhsMin = rhs.smin();
const APInt &rhsMax = rhs.smax();
unsigned width = rhsMin.getBitWidth();

// Guard against division by zero.
if (rhsMax.isZero())
return ConstantIntRanges::maxRange(width);

APInt zero = APInt::getZero(width);

// For Euclidean mod, result is in [0, max(rhs)-1].
APInt umin = zero;
APInt umax = rhsMax - 1;

// Special case: if dividend is already in [0, min(rhs)), result equals
// dividend. We use rhsMin to ensure this is safe for all possible divisor
// values.
if (rhsMin.isStrictlyPositive() && lhsMin.isNonNegative() &&
lhsMax.ult(rhsMin)) {
umin = lhsMin;
umax = lhsMax;
}
// Special case: sweeping out a contiguous range with constant divisor.
// Only applies when dividend is non-negative to ensure result range is
// contiguous.
else if (rhsMin == rhsMax && lhsMin.isNonNegative() &&
(lhsMax - lhsMin).ult(rhsMax)) {
// For non-negative dividends, Euclidean mod is same as unsigned
// remainder.
umin = lhsMin.urem(rhsMax);
umax = lhsMax.urem(rhsMax);
// Result should be contiguous since we're not wrapping around.
assert(umin.ule(umax) &&
"Range should be contiguous for non-negative dividend");
}

return ConstantIntRanges::fromUnsigned(umin, umax);
}
case AffineExprKind::FloorDiv: {
auto binExpr = cast<AffineBinaryOpExpr>(expr);
ConstantIntRanges lhs =
inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
ConstantIntRanges rhs =
inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
// Affine floordiv requires strictly positive divisor (> 0).
// Clamp divisor lower bound to 1 for tighter range inference.
ConstantIntRanges clampedRhs = clampToPositive(rhs);
return inferFloorDivS({lhs, clampedRhs});
}
case AffineExprKind::CeilDiv: {
auto binExpr = cast<AffineBinaryOpExpr>(expr);
ConstantIntRanges lhs =
inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
ConstantIntRanges rhs =
inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
// Affine ceildiv requires strictly positive divisor (> 0).
// Clamp divisor lower bound to 1 for tighter range inference.
ConstantIntRanges clampedRhs = clampToPositive(rhs);
return inferCeilDivS({lhs, clampedRhs});
}
}
llvm_unreachable("unknown affine expression kind");
}
161 changes: 161 additions & 0 deletions mlir/test/Dialect/Affine/int-range-interface.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s

// CHECK-LABEL: func @affine_apply_constant
// CHECK: test.reflect_bounds {smax = 42 : index, smin = 42 : index, umax = 42 : index, umin = 42 : index}
func.func @affine_apply_constant() -> index {
%0 = affine.apply affine_map<() -> (42)>()
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_add
// CHECK: test.reflect_bounds {smax = 15 : index, smin = 6 : index, umax = 15 : index, umin = 6 : index}
func.func @affine_apply_add() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
%d1 = test.with_bounds { umin = 4 : index, umax = 10 : index,
smin = 4 : index, smax = 10 : index } : index
%0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%d0, %d1)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_mul
// CHECK: test.reflect_bounds {smax = 30 : index, smin = 12 : index, umax = 30 : index, umin = 12 : index}
func.func @affine_apply_mul() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
%s0 = test.with_bounds { umin = 6 : index, umax = 6 : index,
smin = 6 : index, smax = 6 : index } : index
%0 = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%d0)[%s0]
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_floordiv
// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
func.func @affine_apply_floordiv() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 10 : index,
smin = 5 : index, smax = 10 : index } : index
%0 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%d0)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_ceildiv
// CHECK: test.reflect_bounds {smax = 3 : index, smin = 2 : index, umax = 3 : index, umin = 2 : index}
func.func @affine_apply_ceildiv() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 10 : index,
smin = 5 : index, smax = 10 : index } : index
%0 = affine.apply affine_map<(d0) -> (d0 ceildiv 4)>(%d0)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_mod
// CHECK: test.reflect_bounds {smax = 3 : index, smin = 0 : index, umax = 3 : index, umin = 0 : index}
func.func @affine_apply_mod() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 27 : index,
smin = 5 : index, smax = 27 : index } : index
%0 = affine.apply affine_map<(d0) -> (d0 mod 4)>(%d0)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_complex
// CHECK: test.reflect_bounds {smax = 13 : index, smin = 5 : index, umax = 13 : index, umin = 5 : index}
func.func @affine_apply_complex() -> index {
%d0 = test.with_bounds { umin = 10 : index, umax = 20 : index,
smin = 10 : index, smax = 20 : index } : index
%d1 = test.with_bounds { umin = 3 : index, umax = 7 : index,
smin = 3 : index, smax = 7 : index } : index
// (d0 floordiv 2) + (d1 mod 4) = [5, 10] + [0, 3] = [5, 13]
%0 = affine.apply affine_map<(d0, d1) -> (d0 floordiv 2 + d1 mod 4)>(%d0, %d1)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_with_symbols
// CHECK: test.reflect_bounds {smax = 24 : index, smin = 9 : index, umax = 24 : index, umin = 9 : index}
func.func @affine_apply_with_symbols() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
%s0 = test.with_bounds { umin = 3 : index, umax = 4 : index,
smin = 3 : index, smax = 4 : index } : index
// d0 * s0 + s0 = s0 * (d0 + 1) = [3, 4] * [3, 6] = [9, 24]
%0 = affine.apply affine_map<(d0)[s0] -> (d0 * s0 + s0)>(%d0)[%s0]
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_sub
// CHECK: test.reflect_bounds {smax = 1 : index, smin = -8 : index
func.func @affine_apply_sub() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
%d1 = test.with_bounds { umin = 4 : index, umax = 10 : index,
smin = 4 : index, smax = 10 : index } : index
// d0 - d1 = [2, 5] - [4, 10] = [2-10, 5-4] = [-8, 1]
%0 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%d0, %d1)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_mul_constant
// CHECK: test.reflect_bounds {smax = 20 : index, smin = 8 : index, umax = 20 : index, umin = 8 : index}
func.func @affine_apply_mul_constant() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
// d0 * 4 = [2, 5] * 4 = [8, 20]
%0 = affine.apply affine_map<(d0) -> (d0 * 4)>(%d0)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_mod_small_range
// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
func.func @affine_apply_mod_small_range() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 6 : index,
smin = 5 : index, smax = 6 : index } : index
// Small range optimization: 5 mod 4 = 1, 6 mod 4 = 2, so [1, 2]
%0 = affine.apply affine_map<(d0) -> (d0 mod 4)>(%d0)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_mod_already_in_range
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index}
func.func @affine_apply_mod_already_in_range() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 7 : index,
smin = 5 : index, smax = 7 : index } : index
// Dividend [5, 7] already in [0, 10), result equals dividend: [5, 7]
%0 = affine.apply affine_map<(d0) -> (d0 mod 10)>(%d0)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_mod_variable_divisor
// CHECK: test.reflect_bounds {smax = 4 : index, smin = 0 : index, umax = 4 : index, umin = 0 : index}
func.func @affine_apply_mod_variable_divisor() -> index {
%d0 = test.with_bounds { umin = 10 : index, umax = 20 : index,
smin = 10 : index, smax = 20 : index } : index
%s0 = test.with_bounds { umin = 3 : index, umax = 5 : index,
smin = 3 : index, smax = 5 : index } : index
// s0 can be 3, 4, or 5, so result is [0, max(s0)-1] = [0, 4]
%0 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%d0)[%s0]
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}

// CHECK-LABEL: func @affine_apply_mod_negative_dividend
// CHECK: test.reflect_bounds {smax = 3 : index, smin = 0 : index, umax = 3 : index, umin = 0 : index}
func.func @affine_apply_mod_negative_dividend() -> index {
%d0 = test.with_bounds { umin = 0 : index, umax = 2 : index,
smin = -2 : index, smax = 2 : index } : index
// Negative dividend: signed range [-2, 2] mod 4
// Actual results: -2->2, -1->3, 0->0, 1->1, 2->2 (Euclidean mod)
// Range is NOT contiguous, so we return conservative [0, 3]
%0 = affine.apply affine_map<(d0) -> (d0 mod 4)>(%d0)
%1 = test.reflect_bounds %0 : index
func.return %1 : index
}