Skip to content

Commit a242046

Browse files
authored
[TVMScript][Relax] Allow return statement in DataflowBlock (#17131)
Prior to this commit, TVMScript required the return value of a Relax to be specified outside of any `with R.dataflow()` blocks. This resulted in a common pattern, where the return value of a function was first called with `R.output(ret_value)`, to mark `ret_value` as a `tvm::relax::Var` instead of a `tvm::relax::DataflowVar`, followed immediately by a `return ret_value` statement. This commit updates the TVMScript parser to allow a `return` statement inside a `with R.dataflow()` block. This is syntactic sugar that is equivalent to calling `R.output`, followed by a `return`. With this change, the following two TVMScript examples are now equivalent. (Prior to this change, the `return_inside_dataflow` example would raise an error during parsing.) ```python @R.function(private=True) def output_then_return(A: R.Tensor): with R.dataflow(): B = R.add(A, A) C = R.multiply(B, B) R.output(C) return C @R.function(private=True) def return_inside_dataflow(A: R.Tensor): with R.dataflow(): B = R.add(A, A) C = R.multiply(B, B) return C ```
1 parent ff8e416 commit a242046

File tree

3 files changed

+75
-48
lines changed

3 files changed

+75
-48
lines changed

src/script/ir_builder/relax/frame.cc

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -118,36 +118,23 @@ void BlockFrameNode::EnterWithScope() {
118118
}
119119
}
120120

121-
class DataflowBlockRewriter : public tvm::relax::ExprMutator {
121+
class VarReplacer : public tvm::relax::ExprMutator {
122122
public:
123-
static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block,
124-
const Array<tvm::relax::Var>& output_vars) {
125-
DataflowBlockRewriter rewriter(output_vars);
126-
return Downcast<tvm::relax::DataflowBlock>(rewriter.VisitBindingBlock(block));
123+
explicit VarReplacer(
124+
std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash, ObjectPtrEqual>
125+
var_remap) {
126+
var_remap_ = std::move(var_remap);
127127
}
128128

129-
private:
130-
explicit DataflowBlockRewriter(const Array<tvm::relax::Var>& output_vars) {
131-
for (const tvm::relax::Var& var : output_vars) {
132-
output_var_set_.insert(var.get());
133-
}
134-
}
135-
136-
tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final {
137-
auto it = output_var_set_.find(op);
138-
if (it != output_var_set_.end()) {
139-
// Rewrite dataflow vars to global vars
140-
auto n = make_object<tvm::relax::VarNode>(*op);
141-
tvm::relax::Var new_var(n);
142-
this->var_remap_[op->vid] = new_var;
143-
return new_var;
129+
tvm::relax::Var VisitVarDef(const tvm::relax::Var& var) override {
130+
// ExprMutator only applies var_remap_ at usage sites. This
131+
// applies var_remap_ at each definition site as well.
132+
if (auto it = var_remap_.find(var->vid); it != var_remap_.end()) {
133+
return it->second;
144134
} else {
145-
return GetRef<tvm::relax::Var>(op);
135+
return var;
146136
}
147137
}
148-
149-
private:
150-
std::unordered_set<const tvm::relax::VarNode*> output_var_set_;
151138
};
152139

153140
void BlockFrameNode::ExitWithScope() {
@@ -164,25 +151,27 @@ void BlockFrameNode::ExitWithScope() {
164151

165152
// Step 3. Rewrite the dataflow block.
166153
if (is_dataflow) {
167-
// Step 3.1. Rewrite block binding
168-
block = DataflowBlockRewriter::Rewrite(Downcast<tvm::relax::DataflowBlock>(block), output_vars);
169-
170-
// Step 3.2. Collect global vars' reference in bindings
171-
Map<tvm::relax::Id, tvm::relax::Var> new_global_vars;
172-
for (const tvm::relax::Binding& binding : block->bindings) {
173-
if (!binding->var->IsInstance<tvm::relax::DataflowVarNode>()) {
174-
new_global_vars.Set(binding->var->vid, binding->var);
175-
}
154+
// Step 3.0. Define a map to replace variables
155+
Array<tvm::relax::Var> new_output_vars;
156+
std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash, ObjectPtrEqual> var_remap;
157+
for (const auto& output_var : output_vars) {
158+
tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var));
159+
new_output_vars.push_back(new_output_var);
160+
var_remap[output_var->vid] = new_output_var;
176161
}
162+
VarReplacer mutator(std::move(var_remap));
163+
164+
// Step 3.1. Rewrite block binding
165+
block = mutator.VisitBindingBlock(block);
177166

178167
// Step 3.3. Rewrite output vars
179-
Array<tvm::relax::Var> new_output_vars;
180-
for (const auto& var : output_vars) {
181-
auto it = new_global_vars.find(var->vid);
182-
ICHECK(it != new_global_vars.end());
183-
new_output_vars.push_back((*it).second);
184-
}
185168
output_vars = std::move(new_output_vars);
169+
170+
// Step 3.4 Rewrite usage of output var, if any
171+
auto function = FindFunctionFrame("R.dataflow()");
172+
if (function->output.defined()) {
173+
function->output = mutator.VisitExpr(function->output.value());
174+
}
186175
}
187176

188177
// Step 3. Get the last frame from the IRBuilder frame stack.
@@ -196,8 +185,6 @@ void BlockFrameNode::ExitWithScope() {
196185

197186
// Step 5. Push the block frame into the corresponding field of the last frame.
198187
if (const auto* seq_frame = last_frame.as<SeqExprFrameNode>()) {
199-
ICHECK(!seq_frame->output.defined())
200-
<< "The function is not expected to have output values when emitting blocks.";
201188
auto frame = GetRef<SeqExprFrame>(seq_frame);
202189
frame->binding_blocks.push_back(block);
203190
} else {

src/script/ir_builder/relax/ir.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,20 +117,29 @@ void FuncRetValue(const tvm::relax::Expr& value) {
117117
const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
118118
tvm::relax::Expr normalized_value = block_builder->Normalize(value);
119119

120+
IRBuilder ir_builder = IRBuilder::Current();
121+
120122
// Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of
121123
// a function body. Therefore if there is any unended block frame when dealing with function
122124
// return, we should end the block frame.
123-
Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
124-
if (block_frame.defined()) {
125-
block_frame.value()->ExitWithScope();
126-
ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>())
127-
<< "ValueError: Relax functions don't support return in true/false branch of If Node.";
125+
126+
if (auto opt = ir_builder->GetLastFrame<BlockFrame>()) {
127+
auto block_frame = opt.value();
128+
for (const auto& var : tvm::relax::FreeVars(normalized_value)) {
129+
if (var->IsInstance<tvm::relax::DataflowVarNode>()) {
130+
block_frame->output_vars.push_back(var);
131+
}
132+
}
128133
}
129134
// Step 2. Add the output value to the function frame.
130135
FunctionFrame frame = FindFunctionFrame("return");
131136
CHECK(!frame->output.defined())
132-
<< "ValueError: Relax functions don't support multiple return statement. Please make sure "
133-
"the return statement appears at the end of function.";
137+
<< "ValueError: "
138+
<< "Relax functions do not support multiple return statement. "
139+
<< "However, return of " << normalized_value << " occurred after a return of "
140+
<< frame->output << ". "
141+
<< "Please make sure function only has a single return statement, "
142+
<< "which appears at the end of function.";
134143

135144
frame->output = std::move(normalized_value);
136145
}

tests/python/relax/test_tvmscript_parser.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,5 +2410,36 @@ def inferred_sinfo(
24102410
tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)
24112411

24122412

2413+
def test_return_from_dataflow_block():
2414+
"""Return statements imply
2415+
2416+
The `R.output` statement in a `R.dataflow()` block marks a
2417+
variable that should be a `relax.Var` instead of a
2418+
`relax.DataflowVar`, allowing it to be used outside of the
2419+
`DataflowBlock` that defined it. A relax function's output is not
2420+
part of any binding, and must not contain any `DataflowVar`, so
2421+
these are exposed implicitly.
2422+
2423+
"""
2424+
2425+
@R.function(private=True)
2426+
def output_then_return(A: R.Tensor([16], "float16")):
2427+
with R.dataflow():
2428+
B = R.add(A, A)
2429+
C = R.multiply(B, B)
2430+
R.output(C)
2431+
2432+
return C
2433+
2434+
@R.function(private=True)
2435+
def return_inside_dataflow(A: R.Tensor([16], "float16")):
2436+
with R.dataflow():
2437+
B = R.add(A, A)
2438+
C = R.multiply(B, B)
2439+
return C
2440+
2441+
tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow)
2442+
2443+
24132444
if __name__ == "__main__":
24142445
tvm.testing.main()

0 commit comments

Comments
 (0)