Skip to content

Commit ff884b6

Browse files
authored
[Relax][Transform] Handle tuple return in RemoveUnusedOutputs (#17253)
* [Relax][Transform] Handle tuple return in RemoveUnusedOutputs Prior to this commit, the `relax.transform.RemoveUnusedOutputs` pass only marked a tuple element as used if it occurred in a `TupleGetItem` node. This ignored use cases where a tuple is used as an aggregate object, such as returning a tuple from a function. This would collect incorrect results for a Relax function that calls a subroutine, receives a tuple as the return value of the subroutine, then returns that tuple. This commit updates `RemoveUnusedOutputs` to look for usage of a tuple object, not just for usage in `TupleGetItem`. Closes #17247
1 parent ec28b67 commit ff884b6

File tree

2 files changed

+59
-20
lines changed

2 files changed

+59
-20
lines changed

src/relax/transform/remove_unused_outputs.cc

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -92,29 +92,48 @@ class PartialTupleUsageCollector : ExprVisitor {
9292
}
9393

9494
void VisitExpr_(const TupleGetItemNode* op) override {
95-
Expr tuple = UnwrapBindings(op->tuple);
96-
97-
if (auto call = tuple.as<CallNode>()) {
98-
if (auto opt_callee = call->op.as<GlobalVar>()) {
99-
auto callee = opt_callee.value();
100-
if (auto it = output_usage_mask_.find(callee); it != output_usage_mask_.end()) {
101-
auto& used_indices = it->second;
102-
103-
CHECK_GE(op->index, 0) << "IndexError: "
104-
<< "Indices for TupleGetItem must be non-negative, "
105-
<< "but expression " << GetRef<Expr>(op)
106-
<< " uses a tuple index of " << op->index;
107-
size_t index = op->index;
108-
109-
CHECK_LT(index, used_indices.size())
110-
<< "IndexError: "
111-
<< "Indices for TupleGetItem must be less than the size of the tuple, "
112-
<< "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
113-
<< " for a tuple of size " << used_indices.size();
114-
used_indices[index] = true;
95+
if (auto* usage_mask_ptr = GetCalleeUsageMask(op->tuple)) {
96+
auto& used_indices = *usage_mask_ptr;
97+
98+
CHECK_GE(op->index, 0) << "IndexError: "
99+
<< "Indices for TupleGetItem must be non-negative, "
100+
<< "but expression " << GetRef<Expr>(op) << " uses a tuple index of "
101+
<< op->index;
102+
size_t index = op->index;
103+
104+
CHECK_LT(index, used_indices.size())
105+
<< "IndexError: "
106+
<< "Indices for TupleGetItem must be less than the size of the tuple, "
107+
<< "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
108+
<< " for a tuple of size " << used_indices.size();
109+
used_indices[index] = true;
110+
}
111+
}
112+
113+
void VisitExpr_(const VarNode* op) override {
114+
if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef<Var>(op))) {
115+
auto& usage_mask = *usage_mask_ptr;
116+
for (size_t i = 0; i < usage_mask.size(); i++) {
117+
usage_mask[i] = true;
118+
}
119+
}
120+
}
121+
122+
std::vector<bool>* GetCalleeUsageMask(Expr expr) {
123+
if (!expr->struct_info_.as<TupleStructInfoNode>()) {
124+
return nullptr;
125+
}
126+
127+
expr = UnwrapBindings(expr);
128+
if (auto call = expr.as<CallNode>()) {
129+
if (auto callee = call->op.as<GlobalVar>()) {
130+
if (auto it = output_usage_mask_.find(callee.value()); it != output_usage_mask_.end()) {
131+
return &it->second;
115132
}
116133
}
117134
}
135+
136+
return nullptr;
118137
}
119138

120139
Expr UnwrapBindings(Expr expr) const {

tests/python/relax/test_transform_remove_unused_outputs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,5 +119,25 @@ def func() -> R.Tuple([R.Tensor([16, 16], "int32"), R.Tensor([32, 32], "int32")]
119119
return (A, C)
120120

121121

122+
class TestReturnTuple(BaseCompare):
123+
@I.ir_module
124+
class Before:
125+
@R.function
126+
def main(A: R.Tensor([16, 16], "int32")):
127+
B = R.add(A, A)
128+
out_tuple = Before.func(B)
129+
return out_tuple
130+
131+
@R.function(private=True)
132+
def func(
133+
B: R.Tensor([16, 16], "int32")
134+
) -> R.Tuple(R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32")):
135+
C = R.multiply(B, B)
136+
D = R.add(B, B)
137+
return (C, D)
138+
139+
Expected = Before
140+
141+
122142
if __name__ == "__main__":
123143
tvm.testing.main()

0 commit comments

Comments
 (0)