Skip to content
81 changes: 10 additions & 71 deletions src/target/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,96 +348,35 @@ void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op,
}

void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
using namespace tvm::tir;
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
{
// Prepare the base error message
const auto *msg_node = op->message.as<StringImmNode>();
const auto *msg_node = op->message.as<tvm::tir::StringImmNode>();
ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm";
const std::string &raw_msg = msg_node->value;
const std::string esc_msg = tvm::support::StrEscape(
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
/*escape_whitespace_special_chars=*/true);

// If the assertion condition contains any equality checks anywhere
// in a composite boolean expression, append the actual LHS/RHS values
// Collect all EQ nodes within the condition (including inside And/Or/Not)
std::vector<const EQNode *> eq_nodes;
{
std::vector<PrimExpr> stk;
stk.push_back(op->condition);
while (!stk.empty()) {
PrimExpr cur = stk.back();
stk.pop_back();
if (const auto *eq = cur.as<EQNode>()) {
eq_nodes.push_back(eq);
continue;
}
if (const auto *an = cur.as<AndNode>()) {
stk.push_back(an->a);
stk.push_back(an->b);
continue;
}
if (const auto *on = cur.as<OrNode>()) {
stk.push_back(on->a);
stk.push_back(on->b);
continue;
}
if (const auto *nn = cur.as<NotNode>()) {
stk.push_back(nn->a);
continue;
}
}
}

if (!eq_nodes.empty()) {
// Build a single detailed message that includes all LHS/RHS pairs
// If the assertion is an equality check, append the actual LHS/RHS values
if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
PrintIndent();
stream << "char __tvm_assert_msg_buf[1024];\n";
stream << "char __tvm_assert_msg_buf[512];\n";
PrintIndent();
stream << "int __tvm_assert_msg_len = snprintf(__tvm_assert_msg_buf, "
"sizeof(__tvm_assert_msg_buf), \"%s\", \""
<< esc_msg << "\");\n";

auto escape_for_printf_literal = [&](const std::string &s) {
std::string out;
out.reserve(s.size());
for (char c : s) {
if (c == '%') {
out += "%%";
} else if (c == '"') {
out += "\\\"";
} else if (c == '\\') {
out += "\\\\";
} else {
out.push_back(c);
}
}
return out;
};

for (const auto *eq : eq_nodes) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
std::string lhs_disp = escape_for_printf_literal(lhs);
std::string rhs_disp = escape_for_printf_literal(rhs);
PrintIndent();
stream << "__tvm_assert_msg_len += snprintf(__tvm_assert_msg_buf + "
"__tvm_assert_msg_len, "
"sizeof(__tvm_assert_msg_buf) - __tvm_assert_msg_len, \"; ("
<< lhs_disp << " == " << rhs_disp
<< ") got: %lld, expected: %lld\", (long long)(" << lhs
<< "), (long long)(" << rhs << "));\n";
}
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
"got: %lld\", \""
<< esc_msg << "\", (long long)(" << lhs << "), (long long)("
<< rhs << "));\n";
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
"__tvm_assert_msg_buf);\n";
} else {
// Fallback: just emit the base message
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg
<< "\");\n";
Expand Down
Loading
Loading