Skip to content

Commit 47e964a

Browse files
[Codegen][WebGPU] LetNode common subexpr override (#17302)
This PR overrides the WebGPU codegen function of `tir::LetNode` to adapt to the recent LetNode common subexpression changes. Co-authored-by: Ruihang Lai <[email protected]>
1 parent 541f9c2 commit 47e964a

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/target/source/codegen_webgpu.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,27 @@ void CodeGenWebGPU::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOL
433433
<< PrintExpr(op->condition) << ")";
434434
}
435435

436+
void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
437+
// use ssa form.
438+
if (print_ssa_form_) {
439+
std::string value = PrintExpr(op->value);
440+
ICHECK(!var_idmap_.count(op->var.get()));
441+
var_idmap_[op->var.get()] = value;
442+
} else {
443+
PrintIndent();
444+
std::string value = PrintExpr(op->value);
445+
this->stream << "let " << AllocVarID(op->var.get()) << " : ";
446+
PrintType(op->var.dtype(), this->stream);
447+
this->stream << " = " << value << ";\n";
448+
}
449+
os << PrintExpr(op->body);
450+
// Pop the defined var from var_idmap when exiting its scope.
451+
// We do this because it is hard to completely avoid a same LetNode appearing
452+
// at different places.
453+
bool removed = var_idmap_.erase(op->var.get());
454+
ICHECK(removed);
455+
}
456+
436457
void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*)
437458
if (op->dtype.bits() == 32) {
438459
std::ostringstream temp;

src/target/source/codegen_webgpu.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ class CodeGenWebGPU final : public CodeGenC {
6363
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
6464
void VisitExpr_(const BufferLoadNode* op, std::ostream& os) final; // NOLINT(*)
6565
void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*)
66-
void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
66+
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
67+
void VisitExpr_(const LetNode* op, std::ostream& os) final; // NOLINT(*)
6768
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
6869
void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*)
6970

0 commit comments

Comments
 (0)