diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 3153c0770e38..faf6bd6466ad 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -118,36 +118,23 @@ void BlockFrameNode::EnterWithScope() { } } -class DataflowBlockRewriter : public tvm::relax::ExprMutator { +class VarReplacer : public tvm::relax::ExprMutator { public: - static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block, - const Array& output_vars) { - DataflowBlockRewriter rewriter(output_vars); - return Downcast(rewriter.VisitBindingBlock(block)); + explicit VarReplacer( + std::unordered_map + var_remap) { + var_remap_ = std::move(var_remap); } - private: - explicit DataflowBlockRewriter(const Array& output_vars) { - for (const tvm::relax::Var& var : output_vars) { - output_var_set_.insert(var.get()); - } - } - - tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final { - auto it = output_var_set_.find(op); - if (it != output_var_set_.end()) { - // Rewrite dataflow vars to global vars - auto n = make_object(*op); - tvm::relax::Var new_var(n); - this->var_remap_[op->vid] = new_var; - return new_var; + tvm::relax::Var VisitVarDef(const tvm::relax::Var& var) override { + // ExprMutator only applies var_remap_ at usage sites. This + // applies var_remap_ at each definition site as well. + if (auto it = var_remap_.find(var->vid); it != var_remap_.end()) { + return it->second; } else { - return GetRef(op); + return var; } } - - private: - std::unordered_set output_var_set_; }; void BlockFrameNode::ExitWithScope() { @@ -164,25 +151,27 @@ void BlockFrameNode::ExitWithScope() { // Step 3. Rewrite the dataflow block. if (is_dataflow) { - // Step 3.1. Rewrite block binding - block = DataflowBlockRewriter::Rewrite(Downcast(block), output_vars); - - // Step 3.2. Collect global vars' reference in bindings - Map new_global_vars; - for (const tvm::relax::Binding& binding : block->bindings) { - if (!binding->var->IsInstance()) { - new_global_vars.Set(binding->var->vid, binding->var); - } + // Step 3.0. Define a map to replace variables + Array new_output_vars; + std::unordered_map var_remap; + for (const auto& output_var : output_vars) { + tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var)); + new_output_vars.push_back(new_output_var); + var_remap[output_var->vid] = new_output_var; } + VarReplacer mutator(std::move(var_remap)); + + // Step 3.1. Rewrite block binding + block = mutator.VisitBindingBlock(block); // Step 3.3. Rewrite output vars - Array new_output_vars; - for (const auto& var : output_vars) { - auto it = new_global_vars.find(var->vid); - ICHECK(it != new_global_vars.end()); - new_output_vars.push_back((*it).second); - } output_vars = std::move(new_output_vars); + + // Step 3.4 Rewrite usage of output var, if any + auto function = FindFunctionFrame("R.dataflow()"); + if (function->output.defined()) { + function->output = mutator.VisitExpr(function->output.value()); + } } // Step 3. Get the last frame from the IRBuilder frame stack. @@ -196,8 +185,6 @@ void BlockFrameNode::ExitWithScope() { // Step 5. Push the block frame into the corresponding field of the last frame. if (const auto* seq_frame = last_frame.as()) { - ICHECK(!seq_frame->output.defined()) - << "The function is not expected to have output values when emitting blocks."; auto frame = GetRef(seq_frame); frame->binding_blocks.push_back(block); } else { diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 453c7fdb5522..b2e75d0c3698 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -117,20 +117,29 @@ void FuncRetValue(const tvm::relax::Expr& value) { const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); tvm::relax::Expr normalized_value = block_builder->Normalize(value); + IRBuilder ir_builder = IRBuilder::Current(); + // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of // a function body. Therefore if there is any unended block frame when dealing with function // return, we should end the block frame. - Optional block_frame = IRBuilder::Current()->GetLastFrame(); - if (block_frame.defined()) { - block_frame.value()->ExitWithScope(); - ICHECK(!IRBuilder::Current()->FindFrame()) - << "ValueError: Relax functions don't support return in true/false branch of If Node."; + + if (auto opt = ir_builder->GetLastFrame()) { + auto block_frame = opt.value(); + for (const auto& var : tvm::relax::FreeVars(normalized_value)) { + if (var->IsInstance()) { + block_frame->output_vars.push_back(var); + } + } } // Step 2. Add the output value to the function frame. FunctionFrame frame = FindFunctionFrame("return"); CHECK(!frame->output.defined()) - << "ValueError: Relax functions don't support multiple return statement. Please make sure " - "the return statement appears at the end of function."; + << "ValueError: " + << "Relax functions do not support multiple return statement. " + << "However, return of " << normalized_value << " occurred after a return of " + << frame->output << ". " + << "Please make sure function only has a single return statement, " + << "which appears at the end of function."; frame->output = std::move(normalized_value); } diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index fd465f320191..fa62d1484893 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -2410,5 +2410,36 @@ def inferred_sinfo( tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) +def test_return_from_dataflow_block(): + """Return statements imply + + The `R.output` statement in a `R.dataflow()` block marks a + variable that should be a `relax.Var` instead of a + `relax.DataflowVar`, allowing it to be used outside of the + `DataflowBlock` that defined it. A relax function's output is not + part of any binding, and must not contain any `DataflowVar`, so + these are exposed implicitly. + + """ + + @R.function(private=True) + def output_then_return(A: R.Tensor([16], "float16")): + with R.dataflow(): + B = R.add(A, A) + C = R.multiply(B, B) + R.output(C) + + return C + + @R.function(private=True) + def return_inside_dataflow(A: R.Tensor([16], "float16")): + with R.dataflow(): + B = R.add(A, A) + C = R.multiply(B, B) + return C + + tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow) + + if __name__ == "__main__": tvm.testing.main()