diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel index f8bedfcc836c..997dfe0f5742 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel @@ -43,6 +43,7 @@ iree_compiler_cc_library( "StripDebugOps.cpp", "TestConversion.cpp", "TestFloatRangeAnalysis.cpp", + "TestIntegerDivisibilityAnalysis.cpp", "VerifyInitializationOrder.cpp", "VerifyStructuredControlFlow.cpp", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt index f73fa2a9bca9..e64d47b6ab6d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt @@ -41,6 +41,7 @@ iree_cc_library( "StripDebugOps.cpp" "TestConversion.cpp" "TestFloatRangeAnalysis.cpp" + "TestIntegerDivisibilityAnalysis.cpp" "VerifyInitializationOrder.cpp" "VerifyStructuredControlFlow.cpp" DEPS diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h index 5445337ddf1a..188c5457b7d0 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h @@ -73,6 +73,7 @@ createHoistIntoGlobalsPass(const ExprHoistingOptions &options); #define GEN_PASS_DECL_STRIPDEBUGOPSPASS #define GEN_PASS_DECL_TESTCONVERSIONPASS #define GEN_PASS_DECL_TESTFLOATRANGEANALYSISPASS +#define GEN_PASS_DECL_TESTINTEGERDIVISIBILITYANALYSISPASS #define GEN_PASS_DECL_VERIFYINITIALIZATIONORDERPASS #define GEN_PASS_DECL_VERIFYSTRUCTUREDCONTROLFLOWPASS #include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc" // IWYU pragma: keep diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td 
index 7093bed69d38..b3f46f78add6 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -346,4 +346,14 @@ def TestFloatRangeAnalysisPass : Pass<"iree-util-test-float-range-analysis", "">
   }];
 }
 
+def TestIntegerDivisibilityAnalysisPass :
+    Pass<"iree-util-test-integer-divisibility-analysis", ""> {
+  let summary = "Tests integer divisibility analysis.";
+  let description = [{
+    Tests integer divisibility analysis by evaluating any
+    'iree_unregistered.test_int_divisibility' op and setting the results on a
+    'divisibility' string attribute of that op.
+  }];
+}
+
 #endif // IREE_DIALECT_UTIL_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
new file mode 100644
index 000000000000..21954d4ec0dc
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
@@ -0,0 +1,72 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+
+namespace mlir::iree_compiler::IREE::Util {
+
+#define GEN_PASS_DEF_TESTINTEGERDIVISIBILITYANALYSISPASS
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc"
+
+namespace {
+
+/// Test-only pass that runs the integer divisibility analysis on the root op
+/// and, for every single-operand `iree_unregistered.test_int_divisibility` op,
+/// records the divisibility computed for that operand in a `divisibility`
+/// string attribute on the op.
+class TestIntegerDivisibilityAnalysisPass
+    : public impl::TestIntegerDivisibilityAnalysisPassBase<
+          TestIntegerDivisibilityAnalysisPass> {
+public:
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+    MLIRContext *context = &getContext();
+
+    // The pass is rooted on `iree_unregistered.test_int_divisibility` ops,
+    // which are expected to have a single operand for which to annotate
+    // divisibility information. Ops with any other operand count are skipped.
+    SmallVector<std::pair<Operation *, Value>> queryOps;
+    rootOp->walk([&](Operation *op) {
+      if (op->getName().getStringRef() ==
+              "iree_unregistered.test_int_divisibility" &&
+          op->getNumOperands() == 1) {
+        queryOps.emplace_back(op, op->getOperand(0));
+      }
+    });
+
+    DataFlowSolver solver;
+    // DeadCodeAnalysis is the base analysis that allows the solver to traverse
+    // control flow. We include it to make the divisibility analysis more
+    // powerful.
+    solver.load<dataflow::DeadCodeAnalysis>();
+    solver.load<IntegerDivisibilityAnalysis>();
+    if (failed(solver.initializeAndRun(rootOp))) {
+      return signalPassFailure();
+    }
+
+    for (auto &[op, value] : queryOps) {
+      auto *lattice = solver.lookupState<IntegerDivisibilityLattice>(value);
+      if (!lattice || lattice->getValue().isUninitialized()) {
+        op->setAttr("divisibility", StringAttr::get(context, "uninitialized"));
+        continue;
+      }
+
+      // Format for the divisibility information is "udiv = X, sdiv = Y".
+      const auto &div = lattice->getValue().getValue();
+      std::string result;
+      llvm::raw_string_ostream os(result);
+      os << "udiv = " << div.udiv() << ", sdiv = " << div.sdiv();
+      op->setAttr("divisibility", StringAttr::get(context, os.str()));
+    }
+  }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
index d3fe86862e8b..64bb4a41f2c0 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
@@ -41,6 +41,7 @@ iree_lit_test_suite(
             "strip_debug_ops.mlir",
             "test_float_range_analysis.mlir",
             "test_float_range_analysis_linalg.mlir",
+            "test_integer_divisibility_analysis.mlir",
             "verify_initialization_order.mlir",
             "verify_structured_control_flow.mlir",
         ],
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index abca38549966..658f9a9582f3 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -39,6 +39,7 @@ iree_lit_test_suite(
     "strip_debug_ops.mlir"
     "test_float_range_analysis.mlir"
     "test_float_range_analysis_linalg.mlir"
+    "test_integer_divisibility_analysis.mlir"
     "verify_initialization_order.mlir"
     "verify_structured_control_flow.mlir"
   TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
new file mode 100644
index 000000000000..998b6f9a5592
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
@@ -0,0 +1,188 @@
+// RUN: iree-opt --split-input-file
--iree-util-test-integer-divisibility-analysis --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: @affine_apply_mul_divisibility
+util.func @affine_apply_mul_divisibility(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 8> : index
+  %1 = affine.apply affine_map<(d0) -> (d0 * 4)>(%0)
+  // CHECK: divisibility = "udiv = 32, sdiv = 32"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_mul_negative
+util.func @affine_apply_mul_negative(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 8> : index
+  %1 = affine.apply affine_map<(d0) -> (d0 * -4)>(%0)
+  // CHECK: divisibility = "udiv = 32, sdiv = 32"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_add_gcd
+util.func @affine_apply_add_gcd(%arg0 : index, %arg1 : index) -> index {
+  %0:2 = util.assume.int %arg0<udiv = 16>,
+                         %arg1<udiv = 24> : index, index
+  %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%0#0, %0#1)
+  // CHECK: divisibility = "udiv = 8, sdiv = 8"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_floordiv_exact
+util.func @affine_apply_floordiv_exact(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 64> : index
+  %1 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%0)
+  // CHECK: divisibility = "udiv = 16, sdiv = 16"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_ceildiv_exact
+util.func @affine_apply_ceildiv_exact(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 64> : index
+  %1 = affine.apply affine_map<(d0) -> (d0 ceildiv 4)>(%0)
+  // CHECK: divisibility = "udiv = 16, sdiv = 16"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_floordiv_non_exact
+util.func @affine_apply_floordiv_non_exact(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 16> : index
+  %1 = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%0)
+  // CHECK: divisibility = "udiv = 1, sdiv = 1"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_mod
+util.func @affine_apply_mod(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 32> : index
+  %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%0)
+  // CHECK: divisibility = "udiv = 1, sdiv = 1"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_composition
+util.func @affine_apply_composition(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 8> : index
+  %1 = affine.apply affine_map<(d0) -> (d0 * 4 + 16)>(%0)
+  // CHECK: divisibility = "udiv = 16, sdiv = 16"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_with_symbol
+util.func @affine_apply_with_symbol(%arg0 : index, %arg1 : index) -> index {
+  %0:2 = util.assume.int %arg0<udiv = 32>,
+                         %arg1<udiv = 48> : index, index
+  %1 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%0#0)[%0#1]
+  // CHECK: divisibility = "udiv = 16, sdiv = 16"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_min_uniform_divisibility
+util.func @affine_min_uniform_divisibility(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 16> : index
+  %1 = affine.min affine_map<(d0) -> (d0, d0 + 64)>(%0)
+  // CHECK: divisibility = "udiv = 16, sdiv = 16"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_min_different_divisibilities
+util.func @affine_min_different_divisibilities(%arg0 : index, %arg1 : index) -> index {
+  %0:2 = util.assume.int %arg0<udiv = 16>,
+                         %arg1<udiv = 24> : index, index
+  %1 = affine.min affine_map<(d0, d1) -> (d0, d1)>(%0#0, %0#1)
+  // CHECK: divisibility = "udiv = 8, sdiv = 8"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_max_uniform_divisibility
+util.func @affine_max_uniform_divisibility(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 32> : index
+  %1 = affine.max affine_map<(d0) -> (d0, d0 - 64)>(%0)
+  // CHECK: divisibility = "udiv = 32, sdiv = 32"
+  %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_max_different_divisibilities
+util.func @affine_max_different_divisibilities(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
+  %0:3 = util.assume.int %arg0<udiv = 12>,
+                         %arg1<udiv = 18>,
+                         %arg2<udiv = 30> : index, index, index
+  %3 = affine.max affine_map<(d0, d1, d2) -> (d0, d1, d2)>(%0#0, %0#1, %0#2)
+  // CHECK: divisibility = "udiv = 6, sdiv = 6"
+  %4 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index
+  util.return %4 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_constant
+util.func @affine_apply_constant() -> index {
+  %0 = affine.apply affine_map<() -> (64)>()
+  // CHECK: divisibility = "udiv = 64, sdiv = 64"
+  %1 = "iree_unregistered.test_int_divisibility"(%0) : (index) -> index
+  util.return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_chained_operations
+util.func @affine_apply_chained_operations(%arg0 : index) -> index {
+  %0 = util.assume.int %arg0<udiv = 4> : index
+  %1 = affine.apply affine_map<(d0) -> (d0 * 8)>(%0)
+  %2 = affine.apply affine_map<(d0) -> (d0 + 16)>(%1)
+  // CHECK: divisibility = "udiv = 16, sdiv = 16"
+  %3 = "iree_unregistered.test_int_divisibility"(%2) : (index) -> index
+  util.return %3 : index
+}
+
+// -----
+
+// CHECK-LABEL: @complex_chained_affine_ops
+util.func @complex_chained_affine_ops(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
+  %0:3 = util.assume.int %arg0<udiv = 210>,
+                         %arg1<udiv = 7>,
+                         %arg2<udiv = 5> : index, index, index
+  %1 = affine.apply affine_map<(d0, d1) -> (d0 + 2 * d1)>(%0#0, %0#1)
+  // CHECK: divisibility = "udiv = 14, sdiv = 14"
+  %div_1 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+  %2 = affine.max affine_map<(d0, d1) -> (d0 floordiv 6, d1 * 3)>(%0#0, %0#2)
+  // CHECK: divisibility = "udiv = 5, sdiv = 5"
+  %div_2 = "iree_unregistered.test_int_divisibility"(%2) : (index) -> index
+  %3 = affine.min affine_map<(d0)[s0] -> (2 * (s0 * d0 - 14) ceildiv 7, d0 floordiv 3 * 2)>(%2)[%1]
+  // CHECK: divisibility = "udiv = 2, sdiv = 2"
+  %div_3 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index
+  util.return %div_3 : index
+}
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
index 4bb133040940..61eafdf3e540 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
@@ -42,6 +42,7 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Dialect/TensorExt/IR",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:BufferizationInterfaces",
         "@llvm-project//mlir:ControlFlowInterfaces",
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
index a183c5539427..efb4db25dbf3 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
@@ -31,6 +31,7 @@ iree_cc_library(
     "UtilExternalModels.cpp"
   DEPS
     LLVMSupport
+    MLIRAffineDialect
     MLIRArithDialect
     MLIRControlFlowInterfaces
     MLIRGPUDialect
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
index
75d442ba463d..78e79aa5efc2 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
@@ -16,11 +16,13 @@
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -49,6 +51,256 @@ getDivisibilityOfOperand(Value v,
   return IREE::Util::ConstantIntDivisibility(1, 1);
 }
 
+/// Visits affine expressions and recursively calculates the divisibility of
+/// each subexpression. Divisibilities of dims and symbols are read from the
+/// caller-provided map (i.e., `divisibilityMap`); expressions without an entry
+/// are derived from their operands or fall back to the minimum divisibility.
+/// The visitor only reads the map and never stores computed results in it.
+class AffineExprDivisibilityFinder
+    : public AffineExprVisitor<AffineExprDivisibilityFinder,
+                               IREE::Util::ConstantIntDivisibility> {
+public:
+  using ExprDivisibilityMap =
+      llvm::DenseMap<AffineExpr, IREE::Util::ConstantIntDivisibility>;
+  explicit AffineExprDivisibilityFinder(ExprDivisibilityMap &divisibilityMap)
+      : divisibilityMap(divisibilityMap) {}
+
+  IREE::Util::ConstantIntDivisibility
+  visitConstantExpr(AffineConstantExpr expr) {
+    // Constant expressions are trivial, since they are always static. The
+    // magnitude is computed without std::abs, which is undefined for INT64_MIN.
+    uint64_t constValue = getMagnitude(expr.getValue());
+    return IREE::Util::ConstantIntDivisibility(constValue, constValue);
+  }
+
+  IREE::Util::ConstantIntDivisibility visitDimExpr(AffineDimExpr expr) {
+    // Dim expressions cannot be analyzed further, so return the divisibility
+    // in `divisibilityMap` if it has been populated by the caller, or fallback
+    // to the minimum divisibility.
+    if (auto it = divisibilityMap.find(expr); it != divisibilityMap.end()) {
+      return it->second;
+    }
+    return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue();
+  }
+
+  IREE::Util::ConstantIntDivisibility visitSymbolExpr(AffineSymbolExpr expr) {
+    // Symbol expressions cannot be analyzed further, so return the divisibility
+    // in `divisibilityMap` if it has been populated by the caller, or fallback
+    // to the minimum divisibility.
+    if (auto it = divisibilityMap.find(expr); it != divisibilityMap.end()) {
+      return it->second;
+    }
+    return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue();
+  }
+
+  /// Infer the divisibility of an addition or subtraction expression by
+  /// recursively visiting the LHS and RHS, and then unioning the results.
+  IREE::Util::ConstantIntDivisibility visitAddExpr(AffineBinaryOpExpr expr) {
+    if (auto it = divisibilityMap.find(expr); it != divisibilityMap.end()) {
+      return it->second;
+    }
+    // The divisibility of an addition is the GCD of its constituents'
+    // divisibilities.
+    IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+    IREE::Util::ConstantIntDivisibility rhsDiv = visit(expr.getRHS());
+    return lhsDiv.getUnion(rhsDiv);
+  }
+
+  /// Infer the divisibility of a multiplication expression by recursively
+  /// visiting the LHS and RHS, and then multiplying the results.
+  IREE::Util::ConstantIntDivisibility visitMulExpr(AffineBinaryOpExpr expr) {
+    if (auto it = divisibilityMap.find(expr); it != divisibilityMap.end()) {
+      return it->second;
+    }
+    // The divisibility of a multiplication is the product of its constituents'
+    // divisibilities. The products are overflow-checked so that a wrapped
+    // 64-bit value is never reported as a divisor.
+    IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+    IREE::Util::ConstantIntDivisibility rhsDiv = visit(expr.getRHS());
+    return IREE::Util::ConstantIntDivisibility(
+        multiplyDivisors(lhsDiv.udiv(), rhsDiv.udiv()),
+        multiplyDivisors(lhsDiv.sdiv(), rhsDiv.sdiv()));
+  }
+
+  IREE::Util::ConstantIntDivisibility
+  visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    return visitDivExpr(expr);
+  }
+
+  IREE::Util::ConstantIntDivisibility
+  visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    return visitDivExpr(expr);
+  }
+
+  /// Mod expressions could be inferred to be zero in some cases, but for now
+  /// just return the minimum divisibility.
+  /// TODO(Max191): Handle evenly divisible cases, and ensure that the zero
+  /// divisibility propagates properly through parent expressions.
+  IREE::Util::ConstantIntDivisibility visitModExpr(AffineBinaryOpExpr expr) {
+    return visitInvalidExpr(expr);
+  }
+
+private:
+  /// Returns the magnitude of `value` as an unsigned integer. Unlike std::abs,
+  /// this is well defined for INT64_MIN, whose magnitude does not fit in
+  /// int64_t.
+  static uint64_t getMagnitude(int64_t value) {
+    auto bits = static_cast<uint64_t>(value);
+    return value < 0 ? uint64_t(0) - bits : bits;
+  }
+
+  /// Multiplies two divisors. If the product does not fit in 64 bits, the
+  /// larger operand is returned instead: any value divisible by `lhs * rhs` is
+  /// also divisible by each factor, so the result is a weaker but still valid
+  /// divisor rather than a wrapped value.
+  static uint64_t multiplyDivisors(uint64_t lhs, uint64_t rhs) {
+    uint64_t product = lhs * rhs;
+    if (lhs != 0 && product / lhs != rhs) {
+      return lhs > rhs ? lhs : rhs;
+    }
+    return product;
+  }
+
+  IREE::Util::ConstantIntDivisibility
+  visitInvalidExpr(AffineBinaryOpExpr expr) {
+    return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue();
+  }
+
+  /// Helper shared by ceildiv and floordiv implementations. Returns the minimum
+  /// divisibility as a fallback if the divisor is not a constant, because the
+  /// divisibility cannot be inferred in this case. If the divisor is a
+  /// constant, then this function recursively visits the dividend, and returns
+  /// the quotient of the dividend's divisibility with the divisor.
+  IREE::Util::ConstantIntDivisibility visitDivExpr(AffineBinaryOpExpr expr) {
+    if (auto it = divisibilityMap.find(expr); it != divisibilityMap.end()) {
+      return it->second;
+    }
+    auto constRhs = dyn_cast<AffineConstantExpr>(expr.getRHS());
+    // Division by zero is undefined, so return the minimum divisibility.
+    if (!constRhs || constRhs.getValue() == 0) {
+      return IREE::Util::ConstantIntDivisibility(1, 1);
+    }
+    uint64_t constValue = getMagnitude(constRhs.getValue());
+    IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+    uint64_t divUDiv =
+        lhsDiv.udiv() % constValue == 0 ? lhsDiv.udiv() / constValue : 1;
+    uint64_t divSDiv =
+        lhsDiv.sdiv() % constValue == 0 ? lhsDiv.sdiv() / constValue : 1;
+    return IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv);
+  }
+
+  ExprDivisibilityMap &divisibilityMap;
+};
+
+/// Returns the divisibilities of each AffineMap result based on the
+/// divisibilities of its dims and symbols. The `dimAndSymbolDivisibilities`
+/// should contain the divisibilities of the dims, followed by the
+/// divisibilities of the symbols in ascending order by their positions.
+static SmallVector<IREE::Util::ConstantIntDivisibility> getResultDivisibilities(
+    AffineMap map,
+    ArrayRef<IREE::Util::ConstantIntDivisibility> dimAndSymbolDivisibilities) {
+  // Seed the AffineExprDivisibilityFinder with the dimAndSymbolDivisibilities.
+  llvm::DenseMap<AffineExpr, IREE::Util::ConstantIntDivisibility>
+      exprDivisibilityMap;
+  SmallVector<AffineExpr> inputExprs;
+  inputExprs.append(llvm::map_to_vector(
+      llvm::seq<int64_t>(map.getNumDims()),
+      [&](int64_t dim) { return getAffineDimExpr(dim, map.getContext()); }));
+  inputExprs.append(llvm::map_to_vector(
+      llvm::seq<int64_t>(map.getNumSymbols()),
+      [&](int64_t sym) { return getAffineSymbolExpr(sym, map.getContext()); }));
+  for (auto [expr, divisibility] :
+       llvm::zip_equal(inputExprs, dimAndSymbolDivisibilities)) {
+    exprDivisibilityMap[expr] = divisibility;
+  }
+  AffineExprDivisibilityFinder divisibilityFinder(exprDivisibilityMap);
+
+  // Walk each result expression and compute their divisibilities.
+  SmallVector<IREE::Util::ConstantIntDivisibility> resultDivisibilities;
+  for (AffineExpr resultExpr : map.getResults()) {
+    resultDivisibilities.push_back(divisibilityFinder.visit(resultExpr));
+  }
+  return resultDivisibilities;
+}
+
+struct AffineApplyInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          AffineApplyInferIntDivisibilityOpInterface, affine::AffineApplyOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
+    SmallVector<IREE::Util::ConstantIntDivisibility> operandDivisibilities;
+    for (auto [operand, divisibility] :
+         llvm::zip(affineApplyOp.getOperands(), argDivs)) {
+      operandDivisibilities.push_back(
+          getDivisibilityOfOperand(operand, divisibility));
+    }
+
+    SmallVector<IREE::Util::ConstantIntDivisibility> resultDivisibilities =
+        getResultDivisibilities(affineApplyOp.getMap(), operandDivisibilities);
+    for (auto [result, divisibility] :
+         llvm::zip_equal(affineApplyOp->getResults(), resultDivisibilities)) {
+      setResultDivs(result, divisibility);
+    }
+  }
+};
+
+/// Infer the result divisibility of an affine.min or affine.max operation
+/// based on its operand divisibilities. The result divisibility is the GCD
+/// of the divisibilities of each of the affine map results, because the result
+/// of the affine.min/max op could be any of these results.
+template <typename MinOrMaxTy>
+static void inferAffineMinOrMaxResultDivisibility(
+    MinOrMaxTy minOrMaxOp, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+    IREE::Util::SetIntDivisibilityFn setResultDivs) {
+  static_assert(
+      llvm::is_one_of<MinOrMaxTy, affine::AffineMinOp,
+                      affine::AffineMaxOp>::value,
+      "MinOrMaxTy must be affine::AffineMinOp or affine::AffineMaxOp");
+  SmallVector<IREE::Util::ConstantIntDivisibility> operandDivisibilities;
+  for (auto [operand, divisibility] :
+       llvm::zip(minOrMaxOp.getOperands(), argDivs)) {
+    operandDivisibilities.push_back(
+        getDivisibilityOfOperand(operand, divisibility));
+  }
+
+  SmallVector<IREE::Util::ConstantIntDivisibility> resultDivisibilities =
+      getResultDivisibilities(minOrMaxOp.getMap(), operandDivisibilities);
+
+  IREE::Util::ConstantIntDivisibility resultDivisibility =
+      resultDivisibilities.pop_back_val();
+  for (auto divisibility : resultDivisibilities) {
+    resultDivisibility = resultDivisibility.getUnion(divisibility);
+  }
+  setResultDivs(minOrMaxOp.getResult(), resultDivisibility);
+}
+
+struct AffineMinInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          AffineMinInferIntDivisibilityOpInterface, affine::AffineMinOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    auto affineMinOp = cast<affine::AffineMinOp>(op);
+    inferAffineMinOrMaxResultDivisibility(affineMinOp, argDivs, setResultDivs);
+  }
+};
+
+struct AffineMaxInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          AffineMaxInferIntDivisibilityOpInterface, affine::AffineMaxOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    auto affineMaxOp = cast<affine::AffineMaxOp>(op);
+    inferAffineMinOrMaxResultDivisibility(affineMaxOp, argDivs, setResultDivs);
+  }
+};
+
 struct ArithConstantInferIntDivisibilityOpInterface
     : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
           ArithConstantInferIntDivisibilityOpInterface, arith::ConstantOp> {
@@ -104,8 +356,13 @@ struct
ArithDivUIInferIntDivisibilityOpInterface auto lhsDivisibility = getDivisibilityOfOperand(divOp.getLhs(), argDivs[0]); - uint64_t divUDiv = lhsDivisibility.udiv() / intVal.getZExtValue(); - uint64_t divSDiv = lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue()); + uint64_t divUDiv = lhsDivisibility.udiv() % intVal.getZExtValue() == 0 + ? lhsDivisibility.udiv() / intVal.getZExtValue() + : 1; + uint64_t divSDiv = + lhsDivisibility.sdiv() % std::abs(intVal.getSExtValue()) == 0 + ? lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue()) + : 1; setResultDivs(divOp, IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv)); } @@ -906,6 +1139,7 @@ struct SCFIndexSwitchOpMutableRegionBranchOpInterface void registerUtilExternalModels(DialectRegistry ®istry) { // Must ensure that any dependent dialects are registered. + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -932,6 +1166,16 @@ void registerUtilExternalModels(DialectRegistry ®istry) { *context); }); + registry.addExtension( + +[](MLIRContext *context, affine::AffineDialect *dialect) { + affine::AffineApplyOp::attachInterface< + AffineApplyInferIntDivisibilityOpInterface>(*context); + affine::AffineMinOp::attachInterface< + AffineMinInferIntDivisibilityOpInterface>(*context); + affine::AffineMaxOp::attachInterface< + AffineMaxInferIntDivisibilityOpInterface>(*context); + }); + registry.addExtension( +[](MLIRContext *context, tensor::TensorDialect *dialect) { tensor::InsertSliceOp::attachInterface(