Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 28 additions & 41 deletions src/script/ir_builder/relax/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,36 +118,23 @@ void BlockFrameNode::EnterWithScope() {
}
}

class DataflowBlockRewriter : public tvm::relax::ExprMutator {
class VarReplacer : public tvm::relax::ExprMutator {
public:
static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block,
const Array<tvm::relax::Var>& output_vars) {
DataflowBlockRewriter rewriter(output_vars);
return Downcast<tvm::relax::DataflowBlock>(rewriter.VisitBindingBlock(block));
explicit VarReplacer(
std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash, ObjectPtrEqual>
var_remap) {
var_remap_ = std::move(var_remap);
}

private:
explicit DataflowBlockRewriter(const Array<tvm::relax::Var>& output_vars) {
for (const tvm::relax::Var& var : output_vars) {
output_var_set_.insert(var.get());
}
}

tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final {
auto it = output_var_set_.find(op);
if (it != output_var_set_.end()) {
// Rewrite dataflow vars to global vars
auto n = make_object<tvm::relax::VarNode>(*op);
tvm::relax::Var new_var(n);
this->var_remap_[op->vid] = new_var;
return new_var;
tvm::relax::Var VisitVarDef(const tvm::relax::Var& var) override {
// ExprMutator only applies var_remap_ at usage sites. This
// applies var_remap_ at each definition site as well.
if (auto it = var_remap_.find(var->vid); it != var_remap_.end()) {
return it->second;
} else {
return GetRef<tvm::relax::Var>(op);
return var;
}
}

private:
std::unordered_set<const tvm::relax::VarNode*> output_var_set_;
};

void BlockFrameNode::ExitWithScope() {
Expand All @@ -164,25 +151,27 @@ void BlockFrameNode::ExitWithScope() {

// Step 3. Rewrite the dataflow block.
if (is_dataflow) {
// Step 3.1. Rewrite block binding
block = DataflowBlockRewriter::Rewrite(Downcast<tvm::relax::DataflowBlock>(block), output_vars);

// Step 3.2. Collect global vars' reference in bindings
Map<tvm::relax::Id, tvm::relax::Var> new_global_vars;
for (const tvm::relax::Binding& binding : block->bindings) {
if (!binding->var->IsInstance<tvm::relax::DataflowVarNode>()) {
new_global_vars.Set(binding->var->vid, binding->var);
}
// Step 3.0. Define a map to replace variables
Array<tvm::relax::Var> new_output_vars;
std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash, ObjectPtrEqual> var_remap;
for (const auto& output_var : output_vars) {
tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var));
new_output_vars.push_back(new_output_var);
var_remap[output_var->vid] = new_output_var;
}
VarReplacer mutator(std::move(var_remap));

// Step 3.1. Rewrite block binding
block = mutator.VisitBindingBlock(block);

// Step 3.3. Rewrite output vars
Array<tvm::relax::Var> new_output_vars;
for (const auto& var : output_vars) {
auto it = new_global_vars.find(var->vid);
ICHECK(it != new_global_vars.end());
new_output_vars.push_back((*it).second);
}
output_vars = std::move(new_output_vars);

// Step 3.4 Rewrite usage of output var, if any
auto function = FindFunctionFrame("R.dataflow()");
if (function->output.defined()) {
function->output = mutator.VisitExpr(function->output.value());
}
}

// Step 3. Get the last frame from the IRBuilder frame stack.
Expand All @@ -196,8 +185,6 @@ void BlockFrameNode::ExitWithScope() {

// Step 5. Push the block frame into the corresponding field of the last frame.
if (const auto* seq_frame = last_frame.as<SeqExprFrameNode>()) {
ICHECK(!seq_frame->output.defined())
<< "The function is not expected to have output values when emitting blocks.";
auto frame = GetRef<SeqExprFrame>(seq_frame);
frame->binding_blocks.push_back(block);
} else {
Expand Down
23 changes: 16 additions & 7 deletions src/script/ir_builder/relax/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,29 @@ void FuncRetValue(const tvm::relax::Expr& value) {
const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
tvm::relax::Expr normalized_value = block_builder->Normalize(value);

IRBuilder ir_builder = IRBuilder::Current();

// Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of
// a function body. Therefore if there is any unended block frame when dealing with function
// return, we should end the block frame.
Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
if (block_frame.defined()) {
block_frame.value()->ExitWithScope();
ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>())
<< "ValueError: Relax functions don't support return in true/false branch of If Node.";

if (auto opt = ir_builder->GetLastFrame<BlockFrame>()) {
auto block_frame = opt.value();
for (const auto& var : tvm::relax::FreeVars(normalized_value)) {
if (var->IsInstance<tvm::relax::DataflowVarNode>()) {
block_frame->output_vars.push_back(var);
}
}
}
// Step 2. Add the output value to the function frame.
FunctionFrame frame = FindFunctionFrame("return");
CHECK(!frame->output.defined())
<< "ValueError: Relax functions don't support multiple return statement. Please make sure "
"the return statement appears at the end of function.";
<< "ValueError: "
<< "Relax functions do not support multiple return statement. "
<< "However, return of " << normalized_value << " occurred after a return of "
<< frame->output << ". "
<< "Please make sure function only has a single return statement, "
<< "which appears at the end of function.";

frame->output = std::move(normalized_value);
}
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2410,5 +2410,36 @@ def inferred_sinfo(
tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)


def test_return_from_dataflow_block():
"""Return statements imply

The `R.output` statement in a `R.dataflow()` block marks a
variable that should be a `relax.Var` instead of a
`relax.DataflowVar`, allowing it to be used outside of the
`DataflowBlock` that defined it. A relax function's output is not
part of any binding, and must not contain any `DataflowVar`, so
these are exposed implicitly.

"""

@R.function(private=True)
def output_then_return(A: R.Tensor([16], "float16")):
with R.dataflow():
B = R.add(A, A)
C = R.multiply(B, B)
R.output(C)

return C

@R.function(private=True)
def return_inside_dataflow(A: R.Tensor([16], "float16")):
with R.dataflow():
B = R.add(A, A)
C = R.multiply(B, B)
return C

tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow)


if __name__ == "__main__":
tvm.testing.main()