Skip to content

Commit 3c6ca5d

Browse files
authored
[Bugfix][Relax] Set purity=false for LazySetOutput (#17119)
The `relax.transform.LazySetOutput` transformation updates a Relax function to produce output from a `fset_output` callback. In the initial implementation, the `fset_output` was marked as a pure function, which allowed it to be erroneously removed from a function. This commit updates the `relax::FuncStructInfo` used to annotate `fset_output`, marking it as an impure function.
1 parent 73cad19 commit 3c6ca5d

File tree

2 files changed

+14
-13
lines changed

src/relax/transform/lazy_transform_params.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class LazyOutputMutator : public ExprMutator {
149149

150150
Var fset_output("fset_output",
151151
FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()},
152-
TupleStructInfo(Array<StructInfo>{})));
152+
TupleStructInfo(Array<StructInfo>{}), /* purity = */ false));
153153
plan_ = FunctionPlan{std::move(output_lookup), fset_output};
154154

155155
std::optional<int64_t> num_input_params = GetNumInputParams(func);
@@ -189,6 +189,7 @@ class LazyOutputMutator : public ExprMutator {
189189
auto write_ptr = node.CopyOnWrite();
190190
write_ptr->params = new_params;
191191
write_ptr->body = new_body;
192+
write_ptr->is_pure = false;
192193
}
193194
if (num_input_params.has_value()) {
194195
node = WithAttr(node, attr::kNumInput, Integer(num_input_params.value() + 1));

tests/python/relax/test_transform_lazy_transform_params.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,11 +1002,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
10021002

10031003
@I.ir_module
10041004
class Expected:
1005-
@R.function
1005+
@R.function(pure=False)
10061006
def transform_params(
10071007
A: R.Tensor([16, 16], "float32"),
10081008
B: R.Tensor([16, 16], "float32"),
1009-
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
1009+
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
10101010
):
10111011
C = R.multiply(A, R.const(2, "float32"))
10121012
fset_output(R.prim_value(1), C)
@@ -1036,11 +1036,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
10361036

10371037
@I.ir_module
10381038
class Expected:
1039-
@R.function
1039+
@R.function(pure=False)
10401040
def transform_params(
10411041
A: R.Tensor([16, 16], "float32"),
10421042
B: R.Tensor([16, 16], "float32"),
1043-
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
1043+
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
10441044
):
10451045
fset_output(R.prim_value(1), B)
10461046
C = R.multiply(A, R.const(2, "float32"))
@@ -1070,10 +1070,10 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
10701070

10711071
@I.ir_module
10721072
class Expected:
1073-
@R.function
1073+
@R.function(pure=False)
10741074
def transform_params(
10751075
A: R.Tensor([16, 16], "float32"),
1076-
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
1076+
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
10771077
B: R.Tensor([16, 16], "float32"),
10781078
):
10791079
R.func_attr({"num_input": 2})
@@ -1105,11 +1105,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
11051105

11061106
@I.ir_module
11071107
class Expected:
1108-
@R.function
1108+
@R.function(pure=False)
11091109
def transform_params(
11101110
A: R.Tensor([16, 16], "float32"),
11111111
B: R.Tensor([16, 16], "float32"),
1112-
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
1112+
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
11131113
):
11141114
C = R.multiply(A, R.const(2, "float32"))
11151115
D = R.add(C, B)
@@ -1140,11 +1140,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
11401140

11411141
@I.ir_module
11421142
class Expected:
1143-
@R.function
1143+
@R.function(pure=False)
11441144
def transform_params(
11451145
A: R.Tensor([16, 16], "float32"),
11461146
B: R.Tensor([16, 16], "float32"),
1147-
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
1147+
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
11481148
):
11491149
C = R.multiply(A, R.const(2, "float32"))
11501150
fset_output(R.prim_value(0), C)
@@ -1171,11 +1171,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
11711171

11721172
@I.ir_module
11731173
class Expected:
1174-
@R.function
1174+
@R.function(pure=False)
11751175
def transform_params(
11761176
A: R.Tensor([16, 16], "float32"),
11771177
B: R.Tensor([16, 16], "float32"),
1178-
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
1178+
fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
11791179
):
11801180
C = R.multiply(A, R.const(2, "float32"))
11811181
D = R.add(C, B)

0 commit comments

Comments (0)