Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -673,32 +673,42 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
} else if (op->op.same_as(builtin::shift_right())) {
PrintBinaryIntrinsic(op, " >> ", os, this);
} else if (op->op.same_as(builtin::if_then_else())) {
// conditional that skips eval if cond evals to false
// Conditional that skips eval if cond evals to false.
// When inside a select, combine conditions to prevent OOB access.
std::string result = name_supply_->FreshName("condval");
std::string cond = PrintExpr(op->args[0]);
std::string outer_cond = select_condition_stack_.empty() ? "" : select_condition_stack_.back();

this->PrintIndent();
PrintType(op->dtype, this->stream);
this->stream << " " << result << ";\n";

// Generate if condition (combine with outer select condition if present)
this->PrintIndent();
this->stream << "if (" << cond << ") {\n";
{
int then_scope = this->BeginScope();
std::string true_val = PrintExpr(op->args[1]);
this->PrintIndent();
this->stream << result << " = " << true_val << ";\n";
this->EndScope(then_scope);
this->PrintIndent();
this->stream << "} else {\n";
}
{
int else_scope = this->BeginScope();
std::string false_val = PrintExpr(op->args[2]);
this->PrintIndent();
this->stream << result << " = " << false_val << ";\n";
this->EndScope(else_scope);
this->PrintIndent();
this->stream << "}\n";
if (outer_cond.empty()) {
this->stream << "if (" << cond << ") {\n";
} else {
this->stream << "if ((" << outer_cond << ") && (" << cond << ")) {\n";
}

// True branch
int then_scope = this->BeginScope();
std::string true_val = PrintExpr(op->args[1]);
this->PrintIndent();
this->stream << result << " = " << true_val << ";\n";
this->EndScope(then_scope);

// False branch
this->PrintIndent();
this->stream << (outer_cond.empty() ? "} else {\n" : "} else if (" + outer_cond + ") {\n");
int else_scope = this->BeginScope();
std::string false_val = PrintExpr(op->args[2]);
this->PrintIndent();
this->stream << result << " = " << false_val << ";\n";
this->EndScope(else_scope);
this->PrintIndent();
this->stream << "}\n";

os << result;
} else if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
Expand Down Expand Up @@ -1059,12 +1069,20 @@ void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLIN
}

void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
std::string cond = PrintExpr(op->condition);
os << "(";
PrintExpr(op->condition, os);
os << cond;
os << " ? ";
// Push condition before processing true_value so that nested if_then_else
// can guard their branches with this condition
select_condition_stack_.push_back(cond);
PrintExpr(op->true_value, os);
select_condition_stack_.pop_back();
os << " : ";
// Push negated condition for false_value
select_condition_stack_.push_back("!(" + cond + ")");
PrintExpr(op->false_value, os);
select_condition_stack_.pop_back();
os << ")";
}

Expand Down
8 changes: 8 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
bool print_ssa_form_{false};
/*! \brief whether the module has a main function declared */
bool has_tvm_ffi_main_func_{false};
/*! \brief Stack of select conditions for if_then_else codegen.
*
* When processing select(cond, true_value, false_value), we push the condition
* before processing true_value. This allows nested if_then_else to guard their
* branches with the outer select condition, preventing potential out-of-bounds
* access when the outer condition is false.
*/
std::vector<std::string> select_condition_stack_;

private:
/*! \brief set of volatile buf access */
Expand Down