diff --git a/paddle/cinn/common/integer_set.cc b/paddle/cinn/common/integer_set.cc index a4fa9ecbae1afe..efe54a192a7d2a 100644 --- a/paddle/cinn/common/integer_set.cc +++ b/paddle/cinn/common/integer_set.cc @@ -164,7 +164,8 @@ cas_intervals_t CollectVarIntervalsOfExprs(const std::vector& exprs, lower_bound = ir::Expr(1); } var_intervals.insert( - {var->name, CasInterval(lower_bound, upper_bound)}); + {var->name, + CasInterval(lower_bound, NormalizeUpperBound(upper_bound))}); } return false; }); @@ -572,6 +573,9 @@ class BoundReplacer : public ir::IRMutator<> { ir::Expr SymbolicExprAnalyzer::LowerBound(const ir::Expr& expr) const { BoundReplacer bound_replacer(var_intervals_, true); ir::Expr bound = ir::ir_utils::IRCopy(expr); + if (bound.is_index()) { + bound = bound.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3); + } bound_replacer(&bound); return optim::ArithSimplify(bound); } @@ -579,7 +583,11 @@ ir::Expr SymbolicExprAnalyzer::LowerBound(const ir::Expr& expr) const { ir::Expr SymbolicExprAnalyzer::UpperBound(const ir::Expr& expr) const { BoundReplacer bound_replacer(var_intervals_, false); ir::Expr bound = ir::ir_utils::IRCopy(expr); + if (bound.is_index()) { + bound = bound.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3); + } bound_replacer(&bound); + return optim::ArithSimplify(bound); } @@ -709,7 +717,8 @@ SingleIntervalIntSet::SingleIntervalIntSet(const ir::Expr& min, ? x->as_var()->upper_bound : SymbolicExprLimit::positive_inf; var_intervals_.insert( - {x->as_var()->name, CasInterval(lower_bound, upper_bound)}); + {x->as_var()->name, + CasInterval(lower_bound, NormalizeUpperBound(upper_bound))}); } return false; }; diff --git a/paddle/cinn/common/ir_util.cc b/paddle/cinn/common/ir_util.cc index 13793271f88c0a..1514c995126348 100644 --- a/paddle/cinn/common/ir_util.cc +++ b/paddle/cinn/common/ir_util.cc @@ -270,6 +270,16 @@ bool is_zero(Expr v) { return false; } +Expr NormalizeUpperBound(Expr upper_bound, bool minus_one /* = true */) { + if (upper_bound == SymbolicExprLimit::positive_inf) { + return upper_bound; + } + if (minus_one) { + return upper_bound - ir::Expr(1); // [lower, upper) to [lower, upper] + } + return upper_bound + ir::Expr(1); // (lower, upper] to [lower, upper) +} + Expr CastIfNeeded(Expr body, Type type) { if (body.type() == type) return body; return ir::Cast::Make(type, body); diff --git a/paddle/cinn/common/ir_util.h b/paddle/cinn/common/ir_util.h index 6eb9a9a6b1b88c..f758ed3db9b4c4 100644 --- a/paddle/cinn/common/ir_util.h +++ b/paddle/cinn/common/ir_util.h @@ -91,6 +91,8 @@ std::vector GatherItersToTensorProducer( bool is_zero(Expr v); +Expr NormalizeUpperBound(Expr upper_bound, bool minus_one = true); + bool MathEqual(const Expr &a, const Expr &b); //! helper function to get a ir::Select node. 
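Note on the interval convention this hunk touches: ir::Var bounds are half-open, i ∈ [lower_bound, upper_bound), while CasInterval stores a closed range [lower_bound, upper_bound]. The new NormalizeUpperBound helper bridges the two (and passes SymbolicExprLimit::positive_inf through unchanged). A minimal standalone sketch of the intended conversion, using plain ints instead of ir::Expr and a hypothetical NormalizeUpperBoundInt name purely for illustration:

// Illustrative sketch only; the real helper operates on ir::Expr and also
// returns positive_inf unchanged.
#include <cassert>

int NormalizeUpperBoundInt(int upper_bound, bool minus_one = true) {
  // minus_one = true : [lower, upper)  ->  [lower, upper - 1] (closed)
  // minus_one = false: [lower, upper]  ->  [lower, upper + 1) (half-open)
  return minus_one ? upper_bound - 1 : upper_bound + 1;
}

int main() {
  // A loop var declared as ir::Var(Expr(0), Expr(8), "i") ranges over [0, 8),
  // so the largest value it can take is 7; that is what CasInterval stores.
  assert(NormalizeUpperBoundInt(8) == 7);
  // Going back from the closed maximum 7 to the exclusive bound gives 8.
  assert(NormalizeUpperBoundInt(7, false) == 8);
  return 0;
}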
diff --git a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc index 333846d6740568..2327d2f3aeeddd 100644 --- a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc @@ -141,7 +141,8 @@ IntSet Evaluate(Expr expr, const std::unordered_map& var_domain) { Expr copy_for_upper_bound = ir::ir_utils::IRCopy(expr); Expr copy_for_lower_bound = ir::ir_utils::IRCopy(expr); - common::cas_intervals_t var_intervals; + common::cas_intervals_t + var_intervals; // variable name -> CasIntervals[lower_bound, upper_bound] std::vector var_vec = ir::ir_utils::CollectIRNodesWithoutTensor( expr, [](const ir::Expr* x) { return x->as_var(); }); for (Expr var_expr : var_vec) { @@ -150,7 +151,9 @@ IntSet Evaluate(Expr expr, const ir::Var& fixed_var = fixed.at(var); var_intervals.emplace( fixed_var->name, - common::CasInterval(fixed_var->lower_bound, fixed_var->upper_bound)); + common::CasInterval( + fixed_var->lower_bound, + cinn::common::NormalizeUpperBound(fixed_var->upper_bound))); optim::ReplaceVarWithExpr(©_for_lower_bound, var, Expr(fixed_var)); optim::ReplaceVarWithExpr(©_for_upper_bound, var, Expr(fixed_var)); } else if (var_domain.count(var) != 0) { @@ -172,7 +175,8 @@ IntSet Evaluate(Expr expr, ::common::errors::InvalidArgument( "The 'upper_bound' of the variable must be defined.")); optim::ReplaceVarWithExpr(©_for_lower_bound, var, var->lower_bound); - optim::ReplaceVarWithExpr(©_for_upper_bound, var, var->upper_bound); + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var, NormalizeUpperBound(var->upper_bound)); } } ir::Expr lower_bound = optim::ArithSimplify(copy_for_lower_bound); diff --git a/paddle/cinn/ir/ir.h b/paddle/cinn/ir/ir.h index 7acf4e110cde2d..cf8b58cd6b57f7 100644 --- a/paddle/cinn/ir/ir.h +++ b/paddle/cinn/ir/ir.h @@ -421,6 +421,7 @@ struct _Var_ : public ExprNode<_Var_> { }; //! A named variable. +// i ∈ [lower_bound, upper_bound) struct Var : public IrNodeRef { Var() = default; explicit Var(IrNode* n) : IrNodeRef(n) {} @@ -846,6 +847,7 @@ struct For : public ExprNode, public ForBase { //! The minimum value of the iteration. Expr min; //! The extent of the iteration. 
+ // loop_var ∈ [min, min + extent) Expr extent; Expr body; diff --git a/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc b/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc index 492738516e95a7..860d285b242aa6 100644 --- a/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc +++ b/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc @@ -621,7 +621,8 @@ std::vector IndicesToVars(const std::vector& indices) { if (e.is_constant()) { std::string var_name = cinn::UniqName("constant" + static_cast(e.get_constant())); - result.emplace_back(e, e, var_name, /* is_reduce = */ false); + result.emplace_back( + e, NormalizeUpperBound(e, false), var_name, /* is_reduce = */ false); } else if (e.As() != nullptr) { ir::Expr copy_e = ir::ir_utils::IRCopy(e); ir::_Var_* var_ref = copy_e.As(); @@ -635,14 +636,17 @@ std::vector IndicesToVars(const std::vector& indices) { ir::Var var = x->as_var_ref(); var_intervals.insert( {var->name, - common::CasInterval{var->lower_bound, var->upper_bound}}); + common::CasInterval{var->lower_bound, + NormalizeUpperBound(var->upper_bound)}}); if (var->is_reduce_axis) is_reduce = true; } return false; }); common::SymbolicExprAnalyzer analyzer(var_intervals); - result.emplace_back( - analyzer.LowerBound(e), analyzer.UpperBound(e), var_name, is_reduce); + result.emplace_back(analyzer.LowerBound(e), + NormalizeUpperBound(analyzer.UpperBound(e), false), + var_name, + is_reduce); } } return result; diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index f1b5d3dfc9f381..1457b61528976a 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -386,6 +386,296 @@ struct SimplifySelectMutator : public ir::IRMutator<> { } }; +/* +Example 1: + Select(a <= b, b, a) → max(a, b) +Example 2: + Select(a <= b, a, b) → min(a, b) +Example 3: + Select(a <= MAX, max(a, MIN), MAX) → min(max(a, MIN), MAX) + Select(a <= MAX, max(MIN, a), MAX) → min(max(a, MIN), MAX) +Example 4: + Select(MIN <= b, min(b, MAX), MIN) → max(min(b, MAX), MIN) + → min(max(b, MIN), MAX) + Select(MIN <= b, min(MAX, b), MIN) → max(min(b, MAX), MIN) + → min(max(b, MIN), MAX) +*/ +struct SimplifySelect2MinMaxMutator : public ir::ExprMutator<> { + void operator()(Expr* x) { ir::ExprMutator<>::Visit(x, x); } + + using ir::ExprMutator<>::Visit; + + // Recursively optimize CompareOp operands + template + void VisitCompare(T* op, Expr* expr) { + Expr a = op->a(); + Expr b = op->b(); + ir::ExprMutator<>::Visit(&a, &a); + ir::ExprMutator<>::Visit(&b, &b); + + if (a.get() != op->a().get() || b.get() != op->b().get()) { + *expr = T::Make(a, b); + } + } + + void Visit(const ir::GE* op, Expr* expr) override { VisitCompare(op, expr); } + void Visit(const ir::GT* op, Expr* expr) override { VisitCompare(op, expr); } + void Visit(const ir::LE* op, Expr* expr) override { VisitCompare(op, expr); } + void Visit(const ir::LT* op, Expr* expr) override { VisitCompare(op, expr); } + + void Visit(const Select* op, Expr* expr) override { + auto* node = expr->As(); + + // 1. Recursively optimize sub-expressions + Expr condition = node->condition; + Expr true_value = node->true_value; + Expr false_value = node->false_value; + + ir::ExprMutator<>::Visit(&condition, &condition); + ir::ExprMutator<>::Visit(&true_value, &true_value); + ir::ExprMutator<>::Visit(&false_value, &false_value); + + // 2. 
If sub-expressions are modified, rebuild the Select node + if (condition.get() != node->condition.get() || + true_value.get() != node->true_value.get() || + false_value.get() != node->false_value.get()) { + *expr = ir::Select::Make(condition, true_value, false_value); + node = expr->As(); + } + + // 3. Function to optimize Select into Min/Max when possible + auto TryOptimizeSelect = [&](const Expr& a, + const Expr& b, + const Expr& x, + const Expr& y) -> Expr { + // Case 1: Select(a <= b, b, a) → max(a, b) + if (x == b && y == a) { + if (b.is_constant()) { + return ir::Max::Make(a, b); + } else { + return ir::Max::Make(b, a); + } + } + // Case 2: Select(a <= b, a, b) → min(a, b) + if (x == a && y == b) { + if (b.is_constant()) { + return ir::Min::Make(a, b); + } else { + return ir::Min::Make(b, a); + } + } + // Case 3: Select(a <= MAX, max(a, MIN), MAX) → min(max(a, MIN), MAX) + if (auto* max = x.As()) { + if (max->a() == a) { + if (max->b().is_constant() && y.is_constant() && b.is_constant()) { + if (y.get_constant() == b.get_constant() && + (max->b()).get_constant() <= y.get_constant()) { + return ir::Min::Make(ir::Max::Make(a, max->b()), b); + } + } + } else if (max->b() == a) { + // Select(a <= MAX, max(MIN, a), MAX) → min(max(a, MIN), MAX) + if (max->a().is_constant() && y.is_constant() && b.is_constant()) { + if (y.get_constant() == b.get_constant() && + (max->a()).get_constant() <= y.get_constant()) { + return ir::Min::Make(ir::Max::Make(a, max->a()), b); + } + } + } + } + // Case 4: Select(MIN <= b, min(b, Max), MIN) → max(min(b, MAX), MIN) + // → min(max(b, MIN), MAX) + if (auto* min = x.As()) { + if (min->a() == b) { + if ((min->b()).is_constant() && y.is_constant() && a.is_constant()) { + if (y.get_constant() == a.get_constant() && + y.get_constant() <= (min->b()).get_constant()) { + return ir::Min::Make(ir::Max::Make(b, a), min->b()); + } + } + } else if (min->b() == b) { + // Select(MIN <= b, min(Max, b), MIN) → min(max(b, MIN), MAX) + if ((min->a()).is_constant() && y.is_constant() && a.is_constant()) { + if (y.get_constant() == a.get_constant() && + y.get_constant() <= (min->a()).get_constant()) { + return ir::Min::Make(ir::Max::Make(b, a), min->a()); + } + } + } + } + return Expr(nullptr); + }; + + // 4. 
Try to optimize different comparison conditions by converting them to + // <= logic + if (auto* ge = node->condition.As()) { + // Select(a >= b, t, f) → Select(b <= a, t, f) + Expr optimized = TryOptimizeSelect( + ge->b(), ge->a(), node->true_value, node->false_value); + if (optimized.defined()) { + *expr = optimized; + return; + } + } else if (auto* gt = node->condition.As()) { + // Select(a > b, t, f) → Select(a <= b, f, t) + Expr optimized = TryOptimizeSelect( + gt->a(), gt->b(), node->false_value, node->true_value); + if (optimized.defined()) { + *expr = optimized; + return; + } + } else if (auto* le = node->condition.As()) { + // Select(a <= b, t, f) → Select(a <= b, t, f) + Expr optimized = TryOptimizeSelect( + le->a(), le->b(), node->true_value, node->false_value); + if (optimized.defined()) { + *expr = optimized; + return; + } + } else if (auto* lt = node->condition.As()) { + // Select(a < b, t, f) → Select(b <= a, f, t) + Expr optimized = TryOptimizeSelect( + lt->b(), lt->a(), node->false_value, node->true_value); + if (optimized.defined()) { + *expr = optimized; + return; + } + } + } +}; + +// Optimizes pow(2.0f, ceil(log2(x))) pattern into more efficient bit +// manipulation: +// Original: pow(2.0f, ceil(log2(x))) +// Optimized: ldexpf(1.0f, exponent) where exponent is calculated via: +// 1. float_as_uint(x) - reinterpret float as uint32 +// 2. right_shift(bits, 23) - extract exponent field +// 3. (exponent_raw & 0xFF) - 127 - adjust IEEE754 bias +// 4. +1 if mantissa is non-zero (for ceil behavior) +struct SimplifyPowerCeilLog2BitOpLdexpfMutator : public ir::ExprMutator<> { + void operator()(Expr* expr) { ir::ExprMutator<>::Visit(expr, expr); } + + using ir::ExprMutator<>::Visit; + void Visit(const ir::Call* op, Expr* expr) override { + /// 1. First recursively process all sub-expressions + std::vector new_args; + for (const auto& arg : op->read_args) { + Expr new_arg = arg; + Visit(&new_arg, &new_arg); + new_args.push_back(new_arg); + } + + // 2. Match target pattern: pow(base, ceil(log2(x))) + if (op->name == "pow" && new_args.size() == 2) { + const Expr& base = new_args[0]; + const Expr& exponent = new_args[1]; + + // Check if exponent is ceil(log2(x)) + if (const ir::Call* ceil_call = exponent.As()) { + if (ceil_call->name == "ceil" && ceil_call->read_args.size() == 1) { + if (const ir::Call* log2_call = + ceil_call->read_args[0].As()) { + if (log2_call->name == "log2" && log2_call->read_args.size() == 1 && + log2_call->read_args[0].type().is_float(32)) { + /// Verify base is 2.0f for optimization + bool is_base_two = false; + if (base.is_constant()) { + if (base.get_constant() == 2.0f) { + is_base_two = true; + } + } + if (is_base_two) { + // 3. 
Replace with bit operations + ldexpf + Expr x = log2_call->read_args[0]; // Extract log2's argument + + // Create bit operations to compute ceil(log2(x)) + // (1) Reinterpret float as 32-bit integer + Expr bits = ir::Call::Make(common::Int(32), + "__float_as_uint", + {x}, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0, + {}); + + std::vector shift_r_args = {bits, ir::Expr(23)}; + Expr shift_r = ir::Call::Make(common::Int(32), + "right_shift", + shift_r_args, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0, + {}); + // (2) Extract exponent part: ((bits >> 23) & 0xFF) - 127 + std::vector bitwise_and_exp_args = { + shift_r, ir::Expr(0xFF)}; + Expr bitwise_and_exp = ir::Call::Make(common::Int(32), + "bitwise_and", + bitwise_and_exp_args, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0, + {}); + Expr exponent_raw = + ir::Sub::Make(bitwise_and_exp, ir::Expr(127)); + // 3. Check if mantissa is non-zero (i.e., if exponent+1 is + // needed) + std::vector bitwise_and_tail_args = { + bits, ir::Expr(0x007FFFFF)}; + Expr bitwise_and_tail = ir::Call::Make(common::Int(32), + "bitwise_and", + bitwise_and_tail_args, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0, + {}); + Expr mantissa_non_zero = + ir::NE::Make(bitwise_and_tail, ir::Expr(0)); + // (4) Check if it's a normal number (exponent != -127) + Expr is_normal = ir::NE::Make(exponent_raw, ir::Expr(-127)); + // (5) If needed, exponent += 1 + Expr exponent_final = ir::Add::Make( + exponent_raw, + ir::Select::Make( + ir::And::Make(is_normal, mantissa_non_zero), + ir::Expr(1), + ir::Expr(0))); + // (6) Create final expression: ldexpf(1.0f, exponent_final) + Expr new_expr = ir::Call::Make(op->type(), + "ldexpf", + {ir::Expr(1.0f), exponent_final}, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0, + {}); + *expr = new_expr; + return; + } + } + } + } + } + } + + // For non-target patterns, reconstruct as-is + if (new_args != op->read_args) { + *expr = ir::Call::Make(op->type(), + op->name, + new_args, + op->write_args, + op->call_type, + op->func, + op->value_index, + op->attrs); + } + } +}; + struct SimplifyUnitBlockMutator : public ir::ExprMutator<> { void operator()(Expr* x) { ir::ExprMutator::Visit(x, x); } @@ -498,6 +788,8 @@ void Simplify(Expr* expr) { SimplifyLogicalMutator()(expr); SimplifyIfThenElseMutator()(expr); SimplifySelectMutator()(expr); + SimplifySelect2MinMaxMutator()(expr); + SimplifyPowerCeilLog2BitOpLdexpfMutator()(expr); SimplifyNoPureMathMutator()(expr); VLOG(6) << "End Simplify " << *expr; } diff --git a/paddle/cinn/optim/simplify_util.cc b/paddle/cinn/optim/simplify_util.cc index 0c02ff5ce9bb89..5fa37a3ccc3d01 100644 --- a/paddle/cinn/optim/simplify_util.cc +++ b/paddle/cinn/optim/simplify_util.cc @@ -677,8 +677,124 @@ std::optional> MatchPattern( return std::nullopt; } +/*! + * \brief Optimize linear division and modulo operations with constant + * denominators. + * + * This function handles linear expressions of the form + * `(a * C1 + b) / C2` and `(a * C1 + b) % C2` + * where C1 and C2 are constants. It specifically targets: + * 1. Linear combinations in the numerator (sums of terms) + * 2. Constant denominators + * + * The optimization: + * 1. Separates terms divisible by the denominator (linear coefficients) + * 2. Groups remaining terms as a remainder expression + * 3. For division: + * - Returns the sum of divisible terms if remainder < denominator + * - Otherwise preserves the original division + * 4. 
For modulo: + * - Returns the remainder if it's provably smaller than denominator + * - Otherwise preserves the original modulo + * + * Example linear optimizations: + * 1. Linear division: (x * 8 + y * 4 + 3) / 4 → x*2 + y + 0 (when 3 < 4) + * 2. Linear modulo: (x * 8 + y * 4 + 3) % 4 → 0 + 0 + 3 + * 3. Partial division: (x * 6 + 5) / 3 → x * 2 + 5 / 3 (when 5 >= 3) + * + * \param expr The linear division/modulo expression to optimize + * \param ana Symbolic analyzer for proving expression bounds + * \return Simplified expression if provably correct, original otherwise + */ +ir::IndexExpr HandleDivModWithConstants( + const ir::IndexExpr &expr, const common::SymbolicExprAnalyzer &ana) { + // Get numerator and denominator + auto numerator = expr.operand(0); + auto denominator = expr.operand(1); + + // Check if denominator is a constant + if (!denominator.is_constant()) { + return expr; + } + int64_t denom_val = denominator.as_int64(); + + // Recursively expand addition chain and collect all terms + std::vector terms = optim::GetFlattenExprs(numerator); + if (terms.empty()) { + return expr; + } + + // Separate terms that are multiples of denominator from other terms + std::vector multiple_terms; + std::vector remainder_terms; + + for (auto &term : terms) { + if (term.node_type() == ir::IrNodeTy::Mul) { + auto rhs = term.operand(1); + if (rhs.is_constant() && rhs.as_int64() % denom_val == 0) { + // Extract terms divisible by denominator + multiple_terms.push_back( + term.operand(0) * + (rhs.as_int64() / denom_val)); // Extract multiplicand part + continue; + } + } + // Extract terms not divisible by denominator + auto remainder_upper = ana.UpperBound(term); + if (!ana.ProveLT(remainder_upper, denominator).value_or(false)) { + return expr; + } + remainder_terms.push_back(term); + } + + // Build remainder expression + ir::IndexExpr remainder_expr; + if (remainder_terms.empty()) { + remainder_expr = ir::IndexExpr(0); + } else if (remainder_terms.size() == 1) { + remainder_expr = remainder_terms[0]; + } else { + remainder_expr = ir::Add::Make(remainder_terms[0], remainder_terms[1]); + for (size_t i = 2; i < remainder_terms.size(); ++i) { + remainder_expr = ir::Add::Make(remainder_expr, remainder_terms[i]); + } + } + + // Build multiplicand terms expression + ir::IndexExpr multiple_expr; + if (multiple_terms.empty()) { + multiple_expr = ir::IndexExpr(0); + } else if (multiple_terms.size() == 1) { + multiple_expr = multiple_terms[0]; + } else { + multiple_expr = ir::Add::Make(multiple_terms[0], multiple_terms[1]); + for (size_t i = 2; i < multiple_terms.size(); ++i) { + multiple_expr = ir::Add::Make(multiple_expr, multiple_terms[i]); + } + } + + // Verify if remainder range is less than denominator + auto remainder_upper = ana.UpperBound(remainder_expr); + if (!ana.ProveLT(remainder_upper, denominator).value_or(false)) { + // If remainder is greater than denominator, the division result is non-zero + if (expr.node_type() == ir::IrNodeTy::Div) { + return ir::Add::Make(multiple_expr, + ir::Div::Make(remainder_expr, denominator)); + } else { // Modulo operation + return ir::Mod::Make(remainder_expr, denominator); + } + } else { + // If remainder is less than denominator, the division result is zero + if (expr.node_type() == ir::IrNodeTy::Div) { + return multiple_expr; + } else { // Modulo operation + return remainder_expr; + } + } +} + ir::IndexExpr BoundSimplify(const ir::IndexExpr &expr) { - // return expr if expr is not a division or modulo + // Return expr if expr is not a division or modulo if 
(expr.node_type() != ir::IrNodeTy::Div && expr.node_type() != ir::IrNodeTy::Mod) return expr; @@ -686,10 +802,10 @@ ir::IndexExpr BoundSimplify(const ir::IndexExpr &expr) { common::cas_intervals_t var_intervals = common::CollectVarIntervalsOfExprs({expr}); common::SymbolicExprAnalyzer ana(var_intervals); - // Because the SymbolicExprAnalyzer bound result is [lower, upper), `ProveLE` - // is used here instead of `ProveLT`. + // Because the SymbolicExprAnalyzer bound result is [lower, upper], + // `ProveLT` is used here instead of `ProveLE`. auto canBeSimplified = - ana.ProveLE(ana.UpperBound(expr.operand(0)), expr.operand(1)); + ana.ProveLT(ana.UpperBound(expr.operand(0)), expr.operand(1)); if (canBeSimplified.value_or(false)) { if (expr.node_type() == ir::IrNodeTy::Div) { @@ -698,7 +814,8 @@ ir::IndexExpr BoundSimplify(const ir::IndexExpr &expr) { return expr.operand(0); } } - return expr; + + return HandleDivModWithConstants(expr, ana); } ir::IndexExpr BroadcastSimplify(const ir::IndexExpr &expr) { diff --git a/test/cpp/cinn/common/integer_set_test.cc b/test/cpp/cinn/common/integer_set_test.cc index 6d57f2dd0ed257..3f7afd4bcae50d 100644 --- a/test/cpp/cinn/common/integer_set_test.cc +++ b/test/cpp/cinn/common/integer_set_test.cc @@ -24,11 +24,13 @@ namespace common { class TestSymbolicExprAnalyzer : public ::testing::Test { public: void SetUp() override { - i = ir::Var(ir::Expr(0), ir::Expr(7), "i"); - j = ir::Var(ir::Expr(0), ir::Expr(15), "j"); + // Var is [lower_bound, upper_bound) + i = ir::Var(ir::Expr(0), ir::Expr(7), "i"); // i ∈ [0, 7) + j = ir::Var(ir::Expr(0), ir::Expr(15), "j"); // j ∈ [0, 15) + // CasInterval is [lower_bound, upper_bound] var_intervals = { - {"i", CasInterval(i->lower_bound, i->upper_bound)}, - {"j", CasInterval(j->lower_bound, j->upper_bound)}, + {"i", CasInterval(i->lower_bound, i->upper_bound - 1)}, // i ∈ [0, 6] + {"j", CasInterval(j->lower_bound, j->upper_bound - 1)}, // j ∈ [0, 14] }; } @@ -41,35 +43,35 @@ class TestSymbolicExprAnalyzer : public ::testing::Test { TEST_F(TestSymbolicExprAnalyzer, bound) { ir::Expr e1 = i + j; EXPECT_EQ(analyzer.LowerBound(e1), ir::Expr(0)); - EXPECT_EQ(analyzer.UpperBound(e1), ir::Expr(22)); + EXPECT_EQ(analyzer.UpperBound(e1), ir::Expr(20)); // 6 + 14 = 20 ir::Expr e2 = 16 * i + j; EXPECT_EQ(analyzer.LowerBound(e2), ir::Expr(0)); - EXPECT_EQ(analyzer.UpperBound(e2), ir::Expr(127)); + EXPECT_EQ(analyzer.UpperBound(e2), ir::Expr(110)); // 16 * 6 + 14 = 110 ir::Expr e3 = 16 * i + j + 1; EXPECT_EQ(analyzer.LowerBound(e3), ir::Expr(1)); - EXPECT_EQ(analyzer.UpperBound(e3), ir::Expr(128)); + EXPECT_EQ(analyzer.UpperBound(e3), ir::Expr(111)); // 16 * 6 + 15 = 111 ir::Expr e4 = (16 * i + j) / 16; EXPECT_EQ(analyzer.LowerBound(e4), ir::Expr(0)); - EXPECT_EQ(analyzer.UpperBound(e4), ir::Expr(7)); + EXPECT_EQ(analyzer.UpperBound(e4), ir::Expr(6)); // 110 / 16 = 6 ir::Expr e5 = (16 * i + j) % 16; EXPECT_EQ(analyzer.LowerBound(e5), ir::Expr(0)); - EXPECT_EQ(analyzer.UpperBound(e5), ir::Expr(15)); + EXPECT_EQ(analyzer.UpperBound(e5), ir::Expr(14)); // 110 % 16 ir::Expr e6 = i - j; - EXPECT_EQ(analyzer.LowerBound(e6), ir::Expr(-15)); - EXPECT_EQ(analyzer.UpperBound(e6), ir::Expr(7)); + EXPECT_EQ(analyzer.LowerBound(e6), ir::Expr(-14)); // 0 - 14 + EXPECT_EQ(analyzer.UpperBound(e6), ir::Expr(6)); // 6 - 0 ir::Expr e7 = 0 - i - j; - EXPECT_EQ(analyzer.LowerBound(e7), ir::Expr(-22)); - EXPECT_EQ(analyzer.UpperBound(e7), ir::Expr(0)); + EXPECT_EQ(analyzer.LowerBound(e7), ir::Expr(-20)); // 0 - 6 - 14 + EXPECT_EQ(analyzer.UpperBound(e7), 
ir::Expr(0)); // 0 - 0 - 0 ir::Expr e8 = -1 * i - j; - EXPECT_EQ(analyzer.LowerBound(e8), ir::Expr(-22)); - EXPECT_EQ(analyzer.UpperBound(e8), ir::Expr(0)); + EXPECT_EQ(analyzer.LowerBound(e8), ir::Expr(-20)); // -1 * 6 - 14 + EXPECT_EQ(analyzer.UpperBound(e8), ir::Expr(0)); // -1 * 0 - 0 } TEST_F(TestSymbolicExprAnalyzer, compare) { @@ -142,9 +144,9 @@ TEST_F(TestSymbolicExprAnalyzer, Divisible) { auto S = ir::Var(ir::Expr(16), ir::Expr(256), "S"); cas_intervals_t divisible_var_intervals = { - {"x", CasInterval(x->lower_bound, x->upper_bound)}, - {"y", CasInterval(y->lower_bound, y->upper_bound)}, - {"S", CasInterval(S->lower_bound, S->upper_bound)}, + {"x", CasInterval(x->lower_bound, x->upper_bound - ir::Expr(1))}, + {"y", CasInterval(y->lower_bound, y->upper_bound - ir::Expr(1))}, + {"S", CasInterval(S->lower_bound, S->upper_bound - ir::Expr(1))}, }; SymbolicExprAnalyzer divisible_analyzer{divisible_var_intervals}; @@ -323,11 +325,11 @@ TEST(SingleIntervalIntSet, case_1) { } TEST(SingleIntervalIntSet, case_2) { - ir::Var S = ir::Var(ir::Expr(0), ir::Expr(0), "S"); + ir::Var S = ir::Var(ir::Expr(0), ir::Expr(1), "S"); // S ∈ [0, 1) - SingleIntervalIntSet set_0{S, S + Expr(1)}; - SingleIntervalIntSet set_1{Expr(0), Expr(1)}; - SingleIntervalIntSet set_2{Expr(0), Expr(2)}; + SingleIntervalIntSet set_0{S, S + Expr(1)}; // [0, 1] + SingleIntervalIntSet set_1{Expr(0), Expr(1)}; // [0, 1] + SingleIntervalIntSet set_2{Expr(0), Expr(2)}; // [0, 2] EXPECT_TRUE(ProveEQ(set_0, set_1).value()); EXPECT_FALSE(ProveEQ(set_0, set_2).value()); diff --git a/test/cpp/pir/cinn/CMakeLists.txt b/test/cpp/pir/cinn/CMakeLists.txt index 0b3591f64f0f2c..2871d040551ee2 100644 --- a/test/cpp/pir/cinn/CMakeLists.txt +++ b/test/cpp/pir/cinn/CMakeLists.txt @@ -47,6 +47,10 @@ if(WITH_TESTING AND WITH_CINN) paddle_test(eliminate_common_factor_of_local_index_test SRCS eliminate_common_factor_of_local_index_test.cc) + paddle_test(ir_simplify_select_test SRCS ir_simplify_select_test.cc) + + paddle_test(ir_simplify_bound_test SRCS ir_simplify_bound_test.cc) + # DO NOT forget add test name here, otherwise it will not be executed in # CINN CI. 
set(cinn_unit_tests diff --git a/test/cpp/pir/cinn/adt/index_expr_test.cc b/test/cpp/pir/cinn/adt/index_expr_test.cc index 3bc2f4ab4e7ae3..a38041f669b20b 100644 --- a/test/cpp/pir/cinn/adt/index_expr_test.cc +++ b/test/cpp/pir/cinn/adt/index_expr_test.cc @@ -52,6 +52,7 @@ class TestIndexExpr : public ::testing::Test { ir::Var S4, S5, S6, S7, S8, S9, f; }; + TEST_F(TestIndexExpr, IndexExpr_0) { ir::IndexExpr a(14); ir::IndexExpr b(7); @@ -643,10 +644,11 @@ TEST_F(TestIndexExpr, MatchPattern) { EXPECT_EQ(result9->at("x"), x); EXPECT_EQ(result9->at("y"), y); } + TEST_F(TestIndexExpr, BoundSimplify) { ir::Var S0 = ir::Var("S0"); - ir::Var i = ir::Var(ir::Expr(0), ir::Expr(5), "i"); - ir::Var j = ir::Var(ir::Expr(0), S0, "j"); + ir::Var i = ir::Var(ir::Expr(0), ir::Expr(5), "i"); // i ∈ [0, 5) + ir::Var j = ir::Var(ir::Expr(0), S0, "j"); // j ∈ [0, S0) ir::Expr q0 = i / Expr(5); ir::Expr q1 = i / Expr(4); diff --git a/test/cpp/pir/cinn/adt/iter_simplify_test.cc b/test/cpp/pir/cinn/adt/iter_simplify_test.cc index 248855b703ff3b..b09bc9d6f521c7 100644 --- a/test/cpp/pir/cinn/adt/iter_simplify_test.cc +++ b/test/cpp/pir/cinn/adt/iter_simplify_test.cc @@ -47,11 +47,12 @@ class TestIterSimplify : public ::testing::Test { i_j_k_fused = ir::Var(ir::Expr(0), ir::Expr(64), "i_j_k_fused").set_index(1); var_intervals = { - {"i", CasInterval(i->lower_bound, i->upper_bound)}, - {"j", CasInterval(j->lower_bound, j->upper_bound)}, - {"k", CasInterval(k->lower_bound, k->upper_bound)}, + {"i", CasInterval(i->lower_bound, i->upper_bound - ir::Expr(1))}, + {"j", CasInterval(j->lower_bound, j->upper_bound - ir::Expr(1))}, + {"k", CasInterval(k->lower_bound, k->upper_bound - ir::Expr(1))}, {"i_j_k_fused", - CasInterval(i_j_k_fused->lower_bound, i_j_k_fused->upper_bound)}}; + CasInterval(i_j_k_fused->lower_bound, + i_j_k_fused->upper_bound - ir::Expr(1))}}; }; ir::Var i; diff --git a/test/cpp/pir/cinn/ir_simplify_bound_test.cc b/test/cpp/pir/cinn/ir_simplify_bound_test.cc new file mode 100644 index 00000000000000..42206af0b9d9b7 --- /dev/null +++ b/test/cpp/pir/cinn/ir_simplify_bound_test.cc @@ -0,0 +1,191 @@ +// Copyright (c) 2025 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/optim/ir_simplify.h" + +#include + +#include "paddle/cinn/cinn.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_base.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/op/ir_operators.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/utils/ir_nodes_collector.h" +#include "paddle/cinn/ir/utils/stmt_converter.h" +#include "paddle/cinn/utils/string.h" + +namespace cinn { +namespace optim { + +/* +i_j_fused: [0ll, 524288ll) +j_0: [0, 128) +Before Normalize: +(j_0 % 128) +After Normalize: +j_0 +*/ +TEST(IRSimplifyBound, SimplifyMod) { + Context::Global().ResetNameId(); + + // Create input IR matching the specified pattern + // Define loop variable + ir::Var var_j_0 = ir::Var(ir::Expr(0), ir::Expr(128), "j_0"); + + // Final expression + ir::Expr expr = ir::Mod::Make(var_j_0, ir::Expr(128)); + + VLOG(6) << "Before Simplify: " << expr; + auto res = expr.as_index().ir::IndexExpr::Normalize( + ir::IndexExpr::OptLevel::kLevel3); + VLOG(6) << "After Simplify: " << res; + + // Expected output verification + std::string expected_ir = R"ROC(j_0)ROC"; + + EXPECT_EQ(utils::GetStreamCnt(res), utils::Trim(expected_ir)); +} + +/* +i_j_fused: [0ll, 524288ll) +j_0: [0, 128) +Before Normalize: +(j_0 / 128) +After Normalize: +0 +*/ +TEST(IRSimplifyBound, SimplifyDiv) { + Context::Global().ResetNameId(); + + // Create input IR matching the specified pattern + // Define loop variable + ir::Var var_j_0 = ir::Var(ir::Expr(0), ir::Expr(128), "j_0"); + + // Final expression + ir::Expr expr = ir::Div::Make(var_j_0, ir::Expr(128)); + + VLOG(6) << "Before Normalize: " << expr; + auto res = expr.as_index().ir::IndexExpr::Normalize( + ir::IndexExpr::OptLevel::kLevel3); + VLOG(6) << "After Normalize: " << res; + + // Expected output verification + std::string expected_ir = R"ROC(0)ROC"; + + EXPECT_EQ(utils::GetStreamCnt(res), utils::Trim(expected_ir)); +} + +/* +i_j_fused: [0ll, 524288ll) +j_0: [0, 128) +Before Normalize: +((((i_j_fused % 16) * 128) + j_0) / 128) +After Normalize: +(i_j_fused % 16) +*/ +TEST(IRSimplifyBound, SimplifyLinearDiv) { + Context::Global().ResetNameId(); + + // Create input IR matching the specified pattern + // Define loop variables + ir::Var var_i_j_fused = ir::Var(ir::Expr(0), ir::Expr(524288), "i_j_fused"); + ir::Var var_j_0 = ir::Var(ir::Expr(0), ir::Expr(128), "j_0"); + + // Final expression + ir::Expr expr = ir::Div::Make( + ir::Add::Make(ir::Mul::Make(ir::Mod::Make(var_i_j_fused, ir::Expr(16)), + ir::Expr(128)), + var_j_0), + ir::Expr(128)); + + VLOG(6) << "Before Normalize: " << expr; + auto res = expr.as_index().ir::IndexExpr::Normalize( + ir::IndexExpr::OptLevel::kLevel3); + VLOG(6) << "After Normalize: " << res; + + // Expected output verification + std::string expected_ir = R"ROC((i_j_fused % 16))ROC"; + + EXPECT_EQ(utils::GetStreamCnt(res), utils::Trim(expected_ir)); +} + +/* +i_j_fused: [0ll, 524288ll) +j_0: [0, 128) +Before Normalize: +((((i_j_fused % 16) * 128) + j_0) % 128) +After Normalize: +j_0 +*/ +TEST(IRSimplifyBound, SimplifyLinearMod) { + Context::Global().ResetNameId(); + + // Create input IR matching the specified pattern + // Define loop variables + ir::Var var_i_j_fused = ir::Var(ir::Expr(0), ir::Expr(524288), "i_j_fused"); + ir::Var var_j_0 = ir::Var(ir::Expr(0), ir::Expr(128), "j_0"); + + // Final expression + ir::Expr expr = ir::Mod::Make( + ir::Add::Make(ir::Mul::Make(ir::Mod::Make(var_i_j_fused, ir::Expr(16)), + ir::Expr(128)), + var_j_0), + ir::Expr(128)); + + VLOG(6) << "Before 
Normalize: " << expr; + auto res = expr.as_index().ir::IndexExpr::Normalize( + ir::IndexExpr::OptLevel::kLevel3); + VLOG(6) << "After Normalize: " << res; + + // Expected output verification + std::string expected_ir = R"ROC(j_0)ROC"; + + EXPECT_EQ(utils::GetStreamCnt(res), utils::Trim(expected_ir)); +} + +/* +loop_var_2: [0, 32) +loop_var_3: [0, 4) +Before Normalize: +(((loop_var_3 * 32ll) + loop_var_2) / 128ll) +After Normalize: +0 +*/ +TEST(IRSimplifyBound, SimplifyLinearDiv2) { + Context::Global().ResetNameId(); + + // Create input IR matching the specified pattern + // Define loop variables + ir::Var loop_var_2 = ir::Var(ir::Expr(0), ir::Expr(32), "loop_var_2"); + ir::Var loop_var_3 = ir::Var(ir::Expr(0), ir::Expr(4), "loop_var_3"); + + // Final expression + ir::Expr expr = ir::Div::Make( + ir::Add::Make(ir::Mul::Make(loop_var_3, ir::Expr(32)), loop_var_2), + ir::Expr(128)); + + VLOG(6) << "Before Normalize: " << expr; + auto res = expr.as_index().ir::IndexExpr::Normalize( + ir::IndexExpr::OptLevel::kLevel3); + VLOG(6) << "After Normalize: " << res; + + // Expected output verification + std::string expected_ir = R"ROC(0)ROC"; + + EXPECT_EQ(utils::GetStreamCnt(res), utils::Trim(expected_ir)); +} + +} // namespace optim +} // namespace cinn diff --git a/test/cpp/pir/cinn/ir_simplify_select_test.cc b/test/cpp/pir/cinn/ir_simplify_select_test.cc new file mode 100644 index 00000000000000..0f236e9d266865 --- /dev/null +++ b/test/cpp/pir/cinn/ir_simplify_select_test.cc @@ -0,0 +1,336 @@ +// Copyright (c) 2025 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/optim/ir_simplify.h" + +#include + +#include "paddle/cinn/cinn.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/op/ir_operators.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/utils/ir_nodes_collector.h" +#include "paddle/cinn/ir/utils/stmt_converter.h" +#include "paddle/cinn/utils/string.h" + +namespace cinn { +namespace optim { + +/* +serial for (i, 0ll, 32768ll) { + serial for (j, 0ll, 16ll) { + serial for (reduce_k_0, 0ll, 128ll) { + var_18[i, j] = select((var_18[i, j] > var_17[i, j, reduce_k_0]), +var_18[i, j], var_17[i, j, reduce_k_0]) + } + } + } +} +*/ +TEST(IRSimplifySelect, SimplifySelectToMax) { + Context::Global().ResetNameId(); + + // Create input IR matching the specified pattern + const std::vector shape_2d = {ir::Expr(32768), ir::Expr(16)}; + const std::vector shape_3d = { + ir::Expr(32768), ir::Expr(16), ir::Expr(128)}; + + ir::Tensor var_17 = + ir::_Tensor_::Make("var_17", ir::Float(32), shape_3d, shape_3d); + var_17->WithBuffer("global", "var_17_buffer"); + + ir::Tensor var_18 = + ir::_Tensor_::Make("var_18", ir::Float(32), shape_2d, shape_2d); + var_18->WithBuffer("global", "var_18_buffer"); + + // Define loop variables + ir::Var var_i = ir::Var(ir::Expr(0), ir::Expr(32768), "i"); + ir::Var var_j = ir::Var(ir::Expr(0), ir::Expr(16), "j"); + ir::Var var_reduce_k_0 = ir::Var(ir::Expr(0), ir::Expr(128), "reduce_k_0"); + + // Create innermost reduction loop body + ir::Expr reduce_body = ir::Store::Make( + var_18, + ir::Select::Make( + ir::GT::Make(ir::Load::Make(var_18, {var_i, var_j}), + ir::Load::Make(var_17, {var_i, var_j, var_reduce_k_0})), + ir::Load::Make(var_18, {var_i, var_j}), + ir::Load::Make(var_17, {var_i, var_j, var_reduce_k_0})), + {var_i, var_j}); + + // Create reduction loop + ir::Expr reduce_loop = ir::For::Make(var_reduce_k_0, + ir::Expr(0), + ir::Expr(128), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({reduce_body})); + + // Create j loop + ir::Expr j_loop = ir::For::Make(var_j, + ir::Expr(0), + ir::Expr(16), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({reduce_loop})); + + // Create i loop + ir::Expr i_loop = ir::For::Make(var_i, + ir::Expr(0), + ir::Expr(32768), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({j_loop})); + + // Final expression + ir::Expr expr = ir::Block::Make({i_loop}); + + VLOG(6) << "Before Simplify: " << expr; + Simplify(&expr); + VLOG(6) << "After Simplify: " << expr; + + // Expected output verification + std::string expected_ir = R"ROC({ + serial for (i, 0, 32768) + { + serial for (j, 0, 16) + { + serial for (reduce_k_0, 0, 128) + { + var_18[i, j] = cinn_max(var_17[i, j, reduce_k_0], var_18[i, j]) + } + } + } +})ROC"; + + EXPECT_EQ(utils::GetStreamCnt(expr), utils::Trim(expected_ir)); +} + +/* +serial for (i, 0ll, 32768ll) { + serial for (j, 0ll, 16ll) { + serial for (reduce_k_0, 0ll, 128ll) { + var_18[i, j] = select((var_18[i, j] < var_17[i, j, reduce_k_0]), +var_18[i, j], var_17[i, j, reduce_k_0]) + } + } + } +} +*/ +TEST(IRSimplifySelect, SimplifySelectToMin) { + Context::Global().ResetNameId(); + + // Create input IR matching the specified pattern + const std::vector shape_2d = {ir::Expr(32768), ir::Expr(16)}; + const std::vector shape_3d = { + ir::Expr(32768), ir::Expr(16), ir::Expr(128)}; + + ir::Tensor var_17 = + ir::_Tensor_::Make("var_17", ir::Float(32), shape_3d, shape_3d); + var_17->WithBuffer("global", "var_17_buffer"); + + ir::Tensor var_18 = + 
ir::_Tensor_::Make("var_18", ir::Float(32), shape_2d, shape_2d); + var_18->WithBuffer("global", "var_18_buffer"); + + // Define loop variables + ir::Var var_i = ir::Var(ir::Expr(0), ir::Expr(32768), "i"); + ir::Var var_j = ir::Var(ir::Expr(0), ir::Expr(16), "j"); + ir::Var var_reduce_k_0 = ir::Var(ir::Expr(0), ir::Expr(128), "reduce_k_0"); + + // Create innermost reduction loop body + ir::Expr reduce_body = ir::Store::Make( + var_18, + ir::Select::Make( + ir::LT::Make(ir::Load::Make(var_18, {var_i, var_j}), + ir::Load::Make(var_17, {var_i, var_j, var_reduce_k_0})), + ir::Load::Make(var_18, {var_i, var_j}), + ir::Load::Make(var_17, {var_i, var_j, var_reduce_k_0})), + {var_i, var_j}); + + // Create reduction loop + ir::Expr reduce_loop = ir::For::Make(var_reduce_k_0, + ir::Expr(0), + ir::Expr(128), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({reduce_body})); + + // Create j loop + ir::Expr j_loop = ir::For::Make(var_j, + ir::Expr(0), + ir::Expr(16), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({reduce_loop})); + + // Create i loop + ir::Expr i_loop = ir::For::Make(var_i, + ir::Expr(0), + ir::Expr(32768), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({j_loop})); + + // Final expression + ir::Expr expr = ir::Block::Make({i_loop}); + + VLOG(6) << "Before Simplify: " << expr; + Simplify(&expr); + VLOG(6) << "After Simplify: " << expr; + + // Expected output verification + std::string expected_ir = R"ROC({ + serial for (i, 0, 32768) + { + serial for (j, 0, 16) + { + serial for (reduce_k_0, 0, 128) + { + var_18[i, j] = cinn_min(var_18[i, j], var_17[i, j, reduce_k_0]) + } + } + } +})ROC"; + + EXPECT_EQ(utils::GetStreamCnt(expr), utils::Trim(expected_ir)); +} + +/* +serial for (i, 0ll, 32768ll) +{ + serial for (j, 0, 16) + { + serial for (j_0, 0, 128) + { + var_45[i, j, j_0)] = select( + (var_18[i, ((((j * 128ll) + j_0) / 128ll) + 0ll)] <= + float32(3.4028234663852886e+38)), + select( + (var_18[i, ((((j * 128ll) + j_0) / 128ll) + 0ll)] >= + float32(9.9999997473787516e-05)), + var_18[i, ((((j * 128ll) + j_0) / 128ll) + 0ll)], + float32(9.9999997473787516e-05) + ), + float32(3.4028234663852886e+38) + ) + } + } +} +*/ +TEST(IRSimplifySelect, SimplifySelectToMinMax) { + Context::Global().ResetNameId(); + + // Create input IR matching the specified pattern + const std::vector shape_2d = {ir::Expr(32768), ir::Expr(16)}; + const std::vector shape_3d = { + ir::Expr(32768), ir::Expr(16), ir::Expr(128)}; + + ir::Tensor var_18 = + ir::_Tensor_::Make("var_18", ir::Float(32), shape_2d, shape_2d); + var_18->WithBuffer("global", "var_18_buffer"); + + ir::Tensor var_45 = + ir::_Tensor_::Make("var_45", ir::Float(32), shape_3d, shape_3d); + var_45->WithBuffer("global", "var_45_buffer"); + + // Define loop variables + ir::Var var_i = ir::Var(ir::Expr(0), ir::Expr(32768), "i"); + ir::Var var_j = ir::Var(ir::Expr(0), ir::Expr(16), "j"); + ir::Var var_j_0 = ir::Var(ir::Expr(0), ir::Expr(128), "j_0"); + + // Create innermost loop body + ir::Expr body = ir::Store::Make( + var_45, + ir::Select::Make( + ir::LE::Make( + ir::Load::Make( + var_18, + {var_i, + ir::Div::Make( + ir::Add::Make(ir::Mul::Make(var_j, ir::Expr(128)), + var_j_0), + ir::Expr(128))}), + ir::Expr(3.4028234663852886e+38f)), + ir::Select::Make( + ir::GE::Make( + ir::Load::Make( + var_18, + {var_i, + ir::Div::Make( + ir::Add::Make(ir::Mul::Make(var_j, ir::Expr(128)), + var_j_0), + ir::Expr(128))}), + ir::Expr(9.9999997473787516e-05f)), + ir::Load::Make( + var_18, + {var_i, + ir::Div::Make( + 
ir::Add::Make(ir::Mul::Make(var_j, ir::Expr(128)), + var_j_0), + ir::Expr(128))}), + ir::Expr(9.9999997473787516e-05f)), + ir::Expr(3.4028234663852886e+38f)), + {var_i, var_j, var_j_0}); + + // Create j_0 loop + ir::Expr j_0_loop = ir::For::Make(var_j_0, + ir::Expr(0), + ir::Expr(128), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({body})); + + // Create j loop + ir::Expr j_loop = ir::For::Make(var_j, + ir::Expr(0), + ir::Expr(16), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({j_0_loop})); + + // Create i loop + ir::Expr i_loop = ir::For::Make(var_i, + ir::Expr(0), + ir::Expr(32768), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({j_loop})); + + // Final expression + ir::Expr expr = ir::Block::Make({i_loop}); + + VLOG(6) << "Before Simplify: " << expr; + Simplify(&expr); + VLOG(6) << "After Simplify: " << expr; + + // Expected output verification + std::string expected_ir = R"ROC({ + serial for (i, 0, 32768) + { + serial for (j, 0, 16) + { + serial for (j_0, 0, 128) + { + var_45[i, j, j_0] = cinn_min(cinn_max(var_18[i, (((j * 128) + j_0) / 128)], 9.99999975e-05f), 3.40282347e+38f) + } + } + } +})ROC"; + + EXPECT_EQ(utils::GetStreamCnt(expr), utils::Trim(expected_ir)); +} +} // namespace optim +} // namespace cinn diff --git a/test/cpp/pir/cinn/ir_simplify_test.cc b/test/cpp/pir/cinn/ir_simplify_test.cc index e682079e72a90a..485216814f0102 100644 --- a/test/cpp/pir/cinn/ir_simplify_test.cc +++ b/test/cpp/pir/cinn/ir_simplify_test.cc @@ -479,5 +479,98 @@ TEST(IRSimplify, if_fold_EQ_2) { } )ROC")); } + +/* +serial for (i_j_fused, 0ll, 524288ll) +{ + serial for (j_0, 0, 128) + { + var_45[(i_j_fused / 16), (((i_j_fused % 16) * 128) + j_0)] = + pow(2.0f, ceil(log2((0.00223214296f * var_31[0])))) + } + } +*/ +TEST(IRSimplifyPowerCeilLog2BitOpLdexpf, Base) { + Context::Global().ResetNameId(); + + /// Create input IR matching the specified pattern + const std::vector shape_2d = {ir::Expr(32768), ir::Expr(16)}; + const std::vector shape_3d = {ir::Expr(32768), ir::Expr(16)}; + + ir::Tensor var_31 = + ir::_Tensor_::Make("var_31", ir::Float(32), shape_2d, shape_2d); + var_31->WithBuffer("global", "var_31_buffer"); + + ir::Tensor var_45 = + ir::_Tensor_::Make("var_45", ir::Float(32), shape_3d, shape_3d); + var_45->WithBuffer("global", "var_45_buffer"); + + // Define loop variables + ir::Var var_i_j_fused = ir::Var(ir::Expr(0), ir::Expr(524288), "i_j_fused"); + ir::Var var_j_0 = ir::Var(ir::Expr(0), ir::Expr(128), "j_0"); + + // Create innermost loop body + ir::Expr body = ir::Store::Make( + var_45, + ir::Call::Make( + ir::Float(32), // Return type + "pow", // Intrinsic function name + {ir::Expr(2.0f), + ir::Call::Make( + ir::Float(32), + "ceil", + {ir::Call::Make( + ir::Float(32), + "log2", + {ir::Mul::Make(ir::Expr(0.00223214296f), + ir::Load::Make(var_31, {ir::Expr(0)}))}, + {}, + ir::CallType::Intrinsic)}, + {}, + ir::CallType::Intrinsic)}, + {}, + ir::CallType::Intrinsic), + {ir::Div::Make(var_i_j_fused, ir::Expr(16)), + ir::Add::Make(ir::Mul::Make(ir::Mod::Make(var_i_j_fused, ir::Expr(16)), + ir::Expr(128)), + var_j_0)}); + + // Create j_0 loop + ir::Expr j_0_loop = ir::For::Make(var_j_0, + ir::Expr(0), + ir::Expr(128), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({body})); + + // Create i_j_fused loop + ir::Expr i_j_fused_loop = ir::For::Make(var_i_j_fused, + ir::Expr(0), + ir::Expr(524288), + ir::ForType::Serial, + ir::DeviceAPI::Host, + ir::Block::Make({j_0_loop})); + + // Final expression + ir::Expr expr = 
ir::Block::Make({i_j_fused_loop}); + + VLOG(6) << "Before Simplify: " << expr; + cinn::optim::Simplify(&expr); + VLOG(6) << "After Simplify: " << expr; + + // Expected output verification + std::string expected_ir = R"ROC({ + serial for (i_j_fused, 0, 524288) + { + serial for (j_0, 0, 128) + { + var_45[(i_j_fused / 16), (((i_j_fused % 16) * 128) + j_0)] = ldexpf(1.00000000f, ((bitwise_and(right_shift(__float_as_uint((0.00223214296f * var_31[0])), 23), 255) - 127) + select((((bitwise_and(right_shift(__float_as_uint((0.00223214296f * var_31[0])), 23), 255) - 127) != -127) and (bitwise_and(__float_as_uint((0.00223214296f * var_31[0])), 8388607) != 0)), 1, 0))) + } + } +})ROC"; + + EXPECT_EQ(utils::GetStreamCnt(expr), utils::Trim(expected_ir)); +} + } // namespace common } // namespace cinn
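For reviewers checking the ldexpf rewrite exercised by the last test: the arithmetic behind SimplifyPowerCeilLog2BitOpLdexpfMutator can be sanity-checked on the host outside CINN. Below is a minimal sketch (my own illustration, not part of the patch; PowCeilLog2 is a hypothetical name) that mirrors the generated expression for positive normal floats and compares it against the original pow(2.0f, ceil(log2(x))) pattern:

// pow(2.0f, ceil(log2(x))) == ldexpf(1.0f, e), where e is recovered from the
// IEEE-754 encoding of x. Valid for positive normal floats.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

float PowCeilLog2(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // __float_as_uint
  // Extract and unbias the exponent field: ((bits >> 23) & 0xFF) - 127.
  int exponent_raw = static_cast<int>((bits >> 23) & 0xFF) - 127;
  // If the mantissa is non-zero, x is strictly above a power of two, so the
  // ceil of log2(x) is one larger (skipped for subnormals, as in the IR).
  bool mantissa_non_zero = (bits & 0x007FFFFF) != 0;
  bool is_normal = exponent_raw != -127;
  int exponent = exponent_raw + ((is_normal && mantissa_non_zero) ? 1 : 0);
  return std::ldexp(1.0f, exponent);  // ldexpf(1.0f, exponent)
}

int main() {
  const float samples[] = {0.00223214296f, 0.75f, 1.0f, 3.0f, 1024.0f};
  for (float x : samples) {
    float reference = std::pow(2.0f, std::ceil(std::log2(x)));
    assert(std::fabs(PowCeilLog2(x) - reference) <= 1e-6f * reference);
  }
  return 0;
}

For example, x = 3.0f encodes as 1.5 * 2^1, so exponent_raw = 1 and the non-zero mantissa bumps it to 2, giving ldexpf(1.0f, 2) = 4.0f, which matches pow(2.0f, ceil(log2(3.0f))).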