|
45 | 45 | #include <vector> |
46 | 46 |
|
47 | 47 | #include "../../target/source/codegen_source_base.h" |
| 48 | +#include "../../tir/transforms/ir_utils.h" |
48 | 49 | #include "../op/annotation/annotation.h" |
49 | 50 | #include "../op/call/call.h" |
50 | 51 | #include "../op/memory/device_copy.h" |
@@ -505,18 +506,34 @@ class AOTExecutorCodegen : public MixedModeVisitor { |
505 | 506 | * copy-on-write fashion. |
506 | 507 | */ |
507 | 508 | void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { |
| 509 | + std::vector<tir::Stmt> let_nest; |
| 510 | + |
508 | 511 | // Define intermediate DLTensor to load/store the data |
509 | 512 | tir::Buffer tmp_read = |
510 | 513 | tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read"); |
511 | 514 | tir::Buffer tmp_write = |
512 | 515 | tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write"); |
513 | | - te::Var loop_idx("i", DataType::Int(32)); |
514 | | - auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); |
| 516 | + |
| 517 | + // Re-use in/out as the buffer var, if possible |
| 518 | + if (auto opt = out.as<tir::Var>()) { |
| 519 | + tmp_write.CopyOnWrite()->data = opt.value(); |
| 520 | + } else { |
| 521 | + let_nest.push_back(tir::LetStmt(tmp_write->data, out, tir::Evaluate(0))); |
| 522 | + } |
| 523 | + if (auto opt = in.as<tir::Var>()) { |
| 524 | + tmp_read.CopyOnWrite()->data = opt.value(); |
| 525 | + } else { |
| 526 | + let_nest.push_back(tir::LetStmt(tmp_read->data, in, tir::Evaluate(0))); |
| 527 | + } |
| 528 | + |
515 | 529 | // Copy the variable from the input to the output |
516 | | - tir::Stmt copy = tir::For( |
517 | | - loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial, |
518 | | - tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); |
519 | | - stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); |
| 530 | + te::Var loop_idx("i", DataType::Int(32)); |
| 531 | + tir::Stmt copy = tir::BufferStore(tmp_write, tir::BufferLoad(tmp_read, {loop_idx}), {loop_idx}); |
| 532 | + copy = tir::For(loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), |
| 533 | + tir::ForKind::kSerial, copy); |
| 534 | + copy = tir::MergeNest(let_nest, copy); |
| 535 | + |
| 536 | + stmts_.push_back(copy); |
520 | 537 | } |
521 | 538 |
|
522 | 539 | /* |
|
0 commit comments