Skip to content

Commit 5457abb

Browse files
committed
[Codegen] Emit tir::Let as var assignment explicitly
Prior to this PR, the PrimExpr `tir::Let` was inlined during codegen, which made any common subexpression elimination (CSE) effort expressed with `tir::Let` at the TIR level ineffective. This PR updates codegen so that each `tir::Let` is emitted as an explicit variable assignment, so the generated code actually reflects the CSE effort.
1 parent 132daf6 commit 5457abb

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

python/tvm/relax/frontend/nn/op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2544,7 +2544,7 @@ def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
25442544

25452545
@T.prim_func(private=True)
25462546
def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
2547-
batch, vocab_size = T.int64(), T.int64()
2547+
batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
25482548
cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
25492549
top_p = T.match_buffer(B, (batch, 1), prob_dtype)
25502550
top_k = T.match_buffer(C, (batch, 1), index_dtype)
@@ -2564,8 +2564,8 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
25642564
def _get_index_from_sorted(
25652565
A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
25662566
):
2567-
batch, vocab_size = T.int64(), T.int64()
2568-
out_batch = T.int64()
2567+
batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
2568+
out_batch = T.int64(is_size_var=True)
25692569
cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
25702570
indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
25712571
renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)

src/target/source/codegen_c.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,8 +887,27 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
887887
let_binding_[op->var] = op;
888888
}
889889
std::string value = PrintExpr(op->value);
890-
var_idmap_[op->var.get()] = value;
890+
if (print_ssa_form_) {
891+
ICHECK(!var_idmap_.count(op->var.get()));
892+
var_idmap_[op->var.get()] = value;
893+
} else {
894+
PrintIndent();
895+
if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) {
896+
PrintType(handle_data_type_.at(op->var.get()), this->stream);
897+
this->stream << "* " << AllocVarID(op->var.get()) << " = (";
898+
PrintType(handle_data_type_.at(op->var.get()), this->stream);
899+
this->stream << "*)" << value << ";\n";
900+
} else {
901+
PrintType(op->var.dtype(), this->stream);
902+
this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n";
903+
}
904+
}
891905
os << PrintExpr(op->body);
906+
// Pop the defined var from var_idmap when exiting its scope.
907+
// We do this because it is hard to completely avoid a same LetNode appearing
908+
// at different places.
909+
bool removed = var_idmap_.erase(op->var.get());
910+
ICHECK(removed);
892911
}
893912

894913
void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)

tests/python/relax/test_frontend_nn_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ def foo(
947947
class Expected:
948948
@T.prim_func(private=True)
949949
def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
950-
batch, vocab_size = T.int64(), T.int64()
950+
batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
951951
cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
952952
indices = T.match_buffer(B, (batch, vocab_size), "int64")
953953
renorm_prob = T.match_buffer(C, (batch, 1))
@@ -970,7 +970,7 @@ def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E:
970970

971971
@T.prim_func(private=True)
972972
def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
973-
batch, vocab_size = T.int64(), T.int64()
973+
batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
974974
cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
975975
top_p = T.match_buffer(B, (batch, 1))
976976
top_k = T.match_buffer(C, (batch, 1), "int64")

0 commit comments

Comments
 (0)