Skip to content

Commit 870ba06

Browse files
committed
Avoid Var-to-Var Let binding in AOTExecutorCodegen
Prior to #14951, these can have erroneous simplifications when used in buffer definitions.
1 parent 0d457a5 commit 870ba06

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include <vector>
4646

4747
#include "../../target/source/codegen_source_base.h"
48+
#include "../../tir/transforms/ir_utils.h"
4849
#include "../op/annotation/annotation.h"
4950
#include "../op/call/call.h"
5051
#include "../op/memory/device_copy.h"
@@ -505,18 +506,34 @@ class AOTExecutorCodegen : public MixedModeVisitor {
505506
* copy-on-write fashion.
506507
*/
507508
void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
509+
std::vector<tir::Stmt> let_nest;
510+
508511
// Define intermediate DLTensor to load/store the data
509512
tir::Buffer tmp_read =
510513
tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
511514
tir::Buffer tmp_write =
512515
tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
513-
te::Var loop_idx("i", DataType::Int(32));
514-
auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
516+
517+
// Re-use in/out as the buffer var, if possible
518+
if (auto opt = out.as<tir::Var>()) {
519+
tmp_write.CopyOnWrite()->data = opt.value();
520+
} else {
521+
let_nest.push_back(tir::LetStmt(tmp_write->data, out, tir::Evaluate(0)));
522+
}
523+
if (auto opt = in.as<tir::Var>()) {
524+
tmp_read.CopyOnWrite()->data = opt.value();
525+
} else {
526+
let_nest.push_back(tir::LetStmt(tmp_read->data, in, tir::Evaluate(0)));
527+
}
528+
515529
// Copy the variable from the input to the output
516-
tir::Stmt copy = tir::For(
517-
loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
518-
tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
519-
stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
530+
te::Var loop_idx("i", DataType::Int(32));
531+
tir::Stmt copy = tir::BufferStore(tmp_write, tir::BufferLoad(tmp_read, {loop_idx}), {loop_idx});
532+
copy = tir::For(loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()),
533+
tir::ForKind::kSerial, copy);
534+
copy = tir::MergeNest(let_nest, copy);
535+
536+
stmts_.push_back(copy);
520537
}
521538

522539
/*

0 commit comments

Comments
 (0)