diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 7c15e23e4ac8..2fe8e44dac57 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -673,32 +673,42 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (op->op.same_as(builtin::shift_right())) { PrintBinaryIntrinsic(op, " >> ", os, this); } else if (op->op.same_as(builtin::if_then_else())) { - // conditional that skips eval if cond evals to false + // Conditional that skips eval if cond evals to false. + // When inside a select, combine conditions to prevent OOB access. std::string result = name_supply_->FreshName("condval"); std::string cond = PrintExpr(op->args[0]); + std::string outer_cond = select_condition_stack_.empty() ? "" : select_condition_stack_.back(); + this->PrintIndent(); PrintType(op->dtype, this->stream); this->stream << " " << result << ";\n"; + + // Generate if condition (combine with outer select condition if present) this->PrintIndent(); - this->stream << "if (" << cond << ") {\n"; - { - int then_scope = this->BeginScope(); - std::string true_val = PrintExpr(op->args[1]); - this->PrintIndent(); - this->stream << result << " = " << true_val << ";\n"; - this->EndScope(then_scope); - this->PrintIndent(); - this->stream << "} else {\n"; - } - { - int else_scope = this->BeginScope(); - std::string false_val = PrintExpr(op->args[2]); - this->PrintIndent(); - this->stream << result << " = " << false_val << ";\n"; - this->EndScope(else_scope); - this->PrintIndent(); - this->stream << "}\n"; + if (outer_cond.empty()) { + this->stream << "if (" << cond << ") {\n"; + } else { + this->stream << "if ((" << outer_cond << ") && (" << cond << ")) {\n"; } + + // True branch + int then_scope = this->BeginScope(); + std::string true_val = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << result << " = " << true_val << ";\n"; + this->EndScope(then_scope); + + // False branch + this->PrintIndent(); + this->stream << (outer_cond.empty() ? "} else {\n" : "} else if (" + outer_cond + ") {\n"); + int else_scope = this->BeginScope(); + std::string false_val = PrintExpr(op->args[2]); + this->PrintIndent(); + this->stream << result << " = " << false_val << ";\n"; + this->EndScope(else_scope); + this->PrintIndent(); + this->stream << "}\n"; + os << result; } else if (op->op.same_as(builtin::address_of())) { const BufferLoadNode* load = op->args[0].as(); @@ -1059,12 +1069,20 @@ void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLIN } void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) + std::string cond = PrintExpr(op->condition); os << "("; - PrintExpr(op->condition, os); + os << cond; os << " ? "; + // Push condition before processing true_value so that nested if_then_else + // can guard their branches with this condition + select_condition_stack_.push_back(cond); PrintExpr(op->true_value, os); + select_condition_stack_.pop_back(); os << " : "; + // Push negated condition for false_value + select_condition_stack_.push_back("!(" + cond + ")"); PrintExpr(op->false_value, os); + select_condition_stack_.pop_back(); os << ")"; } diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 920e6a13a04e..50bd98afccc5 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -322,6 +322,14 @@ class CodeGenC : public ExprFunctor, bool print_ssa_form_{false}; /*! \brief whether the module has a main function declared */ bool has_tvm_ffi_main_func_{false}; + /*! \brief Stack of select conditions for if_then_else codegen. + * + * When processing select(cond, true_value, false_value), we push the condition + * before processing true_value. This allows nested if_then_else to guard their + * branches with the outer select condition, preventing potential out-of-bounds + * access when the outer condition is false. + */ + std::vector select_condition_stack_; private: /*! \brief set of volatile buf access */