Commit ffd47e6
[SLM][Bugfix] Output debug functions as impure
Prior to this commit, debug functions were generated with `relax.call_pure_packed`. This resulted in unexpected behavior, as `nn.op.print_` could be optimized away as a pure function. This commit updates debug functions to be generated as impure functions. This requires removing the `with bb.dataflow()` blocks in the SLM-to-Relax conversions, as impure functions may not be used in a dataflow block. To restore dataflow blocks where legal, the `ConvertToDataflow` pass is applied.
1 parent 254e90a commit ffd47e6
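
To see why the old behavior was a bug, here is a minimal sketch (not part of the commit; names, shapes, and arguments are illustrative) of how a debug hook emitted via `relax.call_pure_packed` can legally be discarded:

# Sketch: why a *pure* debug call can vanish. Assumes the tvm.relax
# Python API at this revision; names below are illustrative.
import tvm
from tvm import relax as rx

bb = rx.BlockBuilder()
x = rx.Var("x", rx.TensorStructInfo((2, 2), "float32"))
with bb.function("main", params=[x]):
    # Pre-fix emission style: the call is wrapped as pure and its result
    # is never used, so optimization passes are free to delete it -- and
    # the print silently disappears.
    unused = bb.emit(
        rx.call_pure_packed(
            "vm.builtin.invoke_debug_func",
            x,
            sinfo_args=[rx.ObjectStructInfo()],
        )
    )
    bb.emit_func_output(x)
mod = bb.finalize()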

11 files changed: +113 −62 lines changed

include/tvm/relax/expr.h

Lines changed: 11 additions & 6 deletions
@@ -983,16 +983,21 @@ class FunctionNode : public BaseFuncNode {
 class Function : public BaseFunc {
  public:
   TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
-                            bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
-                            Span span = Span());
+                            bool is_pure, DictAttrs attrs = NullValue<DictAttrs>(),
+                            Span span = Span())
+      : Function(params, body, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span) {}
+
+  TVM_DLL explicit Function(Array<Var> params, Expr body,
+                            Optional<StructInfo> ret_struct_info = NullOpt,
+                            Optional<Bool> is_pure = NullOpt,
+                            DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
 
   /*!
    * \brief Mimics the constructor but without body Expr.
-   * \note ret_struct_info is required, since it can not deduced by the body.
+   * \note `ret_struct_info` and `is_pure` are required, since it can not deduced by the body.
    */
-  TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
-                                      bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
-                                      Span span = Span());
+  TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bool is_pure,
+                                      DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);

python/tvm/relax/block_builder.py

Lines changed: 2 additions & 2 deletions
@@ -638,8 +638,8 @@ def emit_func_output(
         finally:
             self.end_scope()
 
-        # do not specify ret_struct_info and let constructor deduce
-        # from seqe.struct_info
+        # Do not specify ret_struct_info or purity, and let the
+        # constructor deduce from seqe.struct_info.
         func = rx.Function(self._func._params, seqe)
         for key, value in self._func._attrs.items():
             func = func.with_attr(key, value)

python/tvm/relax/expr.py

Lines changed: 1 addition & 1 deletion
@@ -887,7 +887,7 @@ def __init__(
         params: List[Var],
         body: Expr,
         ret_struct_info: Optional[StructInfo] = None,
-        is_pure: Optional[bool] = True,
+        is_pure: Optional[bool] = None,
        attrs: Optional[tvm.ir.DictAttrs] = None,
        span: Optional[Span] = None,
    ) -> None:
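
A hedged sketch of what the new default means in practice (assumed API; reading the deduced flag off the function's struct info is an assumption on my part, not shown in the commit):

# Sketch: with is_pure left as None, rx.Function deduces purity from
# the body. Assumes the tvm.relax Python API; names are illustrative.
import tvm
from tvm import relax as rx

bb = rx.BlockBuilder()
x = rx.Var("x", rx.TensorStructInfo((4,), "float32"))
with bb.function("f", params=[x]):
    y = bb.emit(rx.op.add(x, x))  # a pure operation
    bb.emit_func_output(y)
func = bb.finalize()["f"]

# emit_func_output (see block_builder.py above) now omits the purity
# flag, so the constructor should deduce a pure function here.
assert func.struct_info.purity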

python/tvm/relax/frontend/nn/exporter.py

Lines changed: 6 additions & 7 deletions
@@ -117,8 +117,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
         with self:
             if effects:
                 with self.builder.function("_initialize_effect"):
-                    with self.builder.dataflow():
-                        outputs = _emit_effect_init(self.builder, effects)
+                    outputs = _emit_effect_init(self.builder, effects)
                     self.builder.emit_func_output(outputs, params=[])
             for method_name, method_spec in zip(spec.method_names, spec.method_specs):
                 params = _params()  # Re-initialize so symbolic shapes not shared across methods
@@ -132,12 +131,12 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
                     method_name,
                     attrs={"num_input": len_args + len_effects},  # type: ignore
                 ):
-                    with self.builder.dataflow():
-                        outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
+                    outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
                     self.builder.emit_func_output(outputs, inputs)
         mod = self.builder.finalize()
         assert rx.analysis.well_formed(mod)
 
+        mod = rx.transform.ConvertToDataflow(min_size=1)(mod)
         return mod, params, ext_mods
 
 
@@ -150,7 +149,7 @@ def _emit_effect_init(
         inits = effect.emit_init(prefix, builder)
         assert isinstance(inits, list)
         outputs.extend(inits)
-    outputs = builder.emit_output(builder.emit(rx.Tuple(outputs)))
+    outputs = builder.emit(rx.Tuple(outputs))
     return outputs
 
 
@@ -281,9 +280,9 @@ def _detuple(arg, var: rx.Var, builder: BlockBuilder):
     for _, effect in effects:
         effect_outputs.extend(effect.finalize())
     if effect_outputs and spec.effect_mode != "none":
-        outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)]))
+        outputs = builder.emit(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)]))
     else:
-        outputs = builder.emit_output(_unwrap_ret(outputs))
+        outputs = builder.emit(_unwrap_ret(outputs))
     return outputs, inputs
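
For reference, `ConvertToDataflow` is an existing `relax.transform` pass; a brief sketch (illustrative module, assumed API) of what the exporter now relies on:

# Sketch: bindings are emitted at function scope, then ConvertToDataflow
# re-wraps maximal runs of pure bindings in dataflow blocks.
import tvm
from tvm import relax as rx

bb = rx.BlockBuilder()
x = rx.Var("x", rx.TensorStructInfo((2, 2), "float32"))
with bb.function("main", params=[x]):
    y = bb.emit(rx.op.add(x, x))       # pure, at function scope
    z = bb.emit(rx.op.multiply(y, y))  # pure, at function scope
    bb.emit_func_output(z)
mod = bb.finalize()

# min_size=1 forms a block around even a single pure binding; impure
# calls (such as the debug hook) are left outside the new blocks.
mod = rx.transform.ConvertToDataflow(min_size=1)(mod)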

python/tvm/relax/frontend/nn/op.py

Lines changed: 7 additions & 8 deletions
@@ -1897,15 +1897,14 @@ def debug_func(lineno: str, arg_0, arg_1, ...) -> None:
         else:
             raise TypeError(f"Unsupported type {type(arg)}")
 
+    func = rx.ExternFunc("vm.builtin.invoke_debug_func")
+    call = rx.Call(
+        func,
+        [io.effect, rx.StringImm(name), rx.StringImm(_line_info), *converted_args],
+        sinfo_args=[rx.ObjectStructInfo()],
+    )
     io.effect = BlockBuilder.current().emit(
-        rx.call_pure_packed(
-            "vm.builtin.invoke_debug_func",
-            io.effect,
-            rx.StringImm(name),
-            rx.StringImm(_line_info),
-            *converted_args,
-            sinfo_args=[rx.ObjectStructInfo()],
-        ),
+        call,
         name_hint=io.effect.name_hint,
     )
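
Side by side, the two emission styles look like this; a hedged sketch with placeholder arguments (the real call sites pass the io effect, function name, and line info shown in the diff above):

# Sketch: the same packed call constructed as impure vs. pure.
# Placeholder strings; assumes the tvm.relax Python API.
import tvm
from tvm import relax as rx

effect = rx.Var("io_effect", rx.ObjectStructInfo())

# New style: a direct Call to an ExternFunc is not assumed pure, so it
# survives optimization but may not sit inside a dataflow block.
impure_call = rx.Call(
    rx.ExternFunc("vm.builtin.invoke_debug_func"),
    [effect, rx.StringImm("my_debug_func"), rx.StringImm("model.py:42")],
    sinfo_args=[rx.ObjectStructInfo()],
)

# Old style: identical arguments, but the wrapper asserts purity, which
# licenses passes to drop the call when its result is unused.
pure_call = rx.call_pure_packed(
    "vm.builtin.invoke_debug_func",
    effect,
    rx.StringImm("my_debug_func"),
    rx.StringImm("model.py:42"),
    sinfo_args=[rx.ObjectStructInfo()],
)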

src/relax/ir/expr.cc

Lines changed: 10 additions & 3 deletions
@@ -441,8 +441,8 @@ TVM_REGISTER_GLOBAL("relax.SeqExpr")
 
 TVM_REGISTER_NODE_TYPE(FunctionNode);
 
-Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info, bool is_pure,
-                   DictAttrs attrs, Span span) {
+Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
+                   Optional<Bool> is_pure_override, DictAttrs attrs, Span span) {
   // Set the function type.
   // For function, we take a conservative approach and require the function type
   // to be known at construction time.
@@ -473,6 +473,13 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
     ret_struct_info = body_sinfo;
   }
 
+  bool is_pure;
+  if (is_pure_override.defined()) {
+    is_pure = is_pure_override.value()->value;
+  } else {
+    is_pure = !ContainsImpureCall(body);
+  }
+
   FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure);
 
   // set the fields
@@ -490,7 +497,7 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
 
 TVM_REGISTER_GLOBAL("relax.Function")
     .set_body_typed([](Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
-                       bool is_pure, DictAttrs attrs, Span span) {
+                       Optional<Bool> is_pure, DictAttrs attrs, Span span) {
       return Function(params, body, ret_struct_info, is_pure, attrs, span);
     });

tests/python/relax/test_frontend_nn_debug.py

Lines changed: 59 additions & 2 deletions
@@ -24,6 +24,8 @@
 from tvm.relax.frontend.nn import op, spec
 from tvm.runtime import NDArray
 
+from tvm.script import ir as I, relax as R
+
 
 def test_debug_print():
     class Layer(nn.Module):
@@ -42,6 +44,62 @@ def forward(self, x: nn.Tensor):  # pylint: disable=invalid-name
     assert isinstance(y, torch.Tensor)
 
 
+def test_debug_print_well_formed():
+    class Layer(nn.Module):
+        def forward(self, state: nn.Tensor):
+            state = state * 2.0
+            op.print_(state)
+            state = state * 2.0
+            return state
+
+    forward_code = Layer.forward.__wrapped__.__code__
+    debug_location = f"{forward_code.co_filename}:{forward_code.co_firstlineno+2}"
+
+    model, _ = Layer().export_tvm(
+        spec={
+            "forward": {"state": spec.Tensor([10, 5], dtype="float32")},
+        },
+        debug=True,
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io = R.null_value()
+                gv = (_io,)
+                R.output(gv)
+            return gv
+
+        @R.function(pure=False)
+        def forward(
+            state: R.Tensor((10, 5), dtype="float32"), _io: R.Object
+        ) -> R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tuple(R.Object)):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                mul = R.multiply(state, R.const(2, "float32"))
+                R.output(mul)
+
+            _io1 = R.call_packed(
+                "vm.builtin.invoke_debug_func",
+                _io,
+                R.str("vm.builtin.debug_print"),
+                R.str(debug_location),
+                mul,
+                sinfo_args=(R.Object,),
+            )
+
+            with R.dataflow():
+                mul1 = R.multiply(mul, R.const(2, "float32"))
+                gv1 = mul1, (_io1,)
+                R.output(gv1)
+
+            return gv1
+
+    tvm.ir.assert_structural_equal(Expected, model)
+
+
 def test_debug_func():
     @tvm.register_func("testing.relax.frontend.nn.test_debug_func")
     def _debug(  # pylint: disable=too-many-arguments
@@ -79,5 +137,4 @@ def forward(self, x: nn.Tensor, v: tir.Var):  # pylint: disable=invalid-name
 
 
 if __name__ == "__main__":
-    test_debug_print()
-    test_debug_func()
+    tvm.testing.main()

tests/python/relax/test_frontend_nn_extern_module.py

Lines changed: 2 additions & 4 deletions
@@ -91,10 +91,9 @@ def scalar_add(
     ) -> R.Tensor((), dtype="float32"):
         R.func_attr({"num_input": 2})
         with R.dataflow():
-            ext_scalar_add = R.call_dps_packed(
+            gv = R.call_dps_packed(
                 "ext_scalar_add", (a, b), out_sinfo=R.Tensor((), dtype="float32")
             )
-            gv: R.Tensor((), dtype="float32") = ext_scalar_add
             R.output(gv)
         return gv
 
@@ -107,10 +106,9 @@ def test_sym(
         z = T.int64()
         R.func_attr({"num_input": 2})
         with R.dataflow():
-            ext_test_sym = R.call_dps_packed(
+            gv1 = R.call_dps_packed(
                 "ext_test_sym", (a, b), out_sinfo=R.Tensor((x, y, z, 9), dtype="float32")
             )
-            gv1: R.Tensor((x, y, z, 9), dtype="float32") = ext_test_sym
             R.output(gv1)
         return gv1

tests/python/relax/test_frontend_nn_op.py

Lines changed: 9 additions & 19 deletions
@@ -532,8 +532,7 @@ def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer(
     def _initialize_effect() -> R.Tuple(R.Object):
         with R.dataflow():
             _io: R.Object = R.null_value()
-            lv: R.Tuple(R.Object) = (_io,)
-            gv: R.Tuple(R.Object) = lv
+            gv = (_io,)
             R.output(gv)
         return gv
 
@@ -605,8 +604,7 @@ def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k:
     def _initialize_effect() -> R.Tuple(R.Object):
         with R.dataflow():
             _io: R.Object = R.null_value()
-            lv: R.Tuple(R.Object) = (_io,)
-            gv: R.Tuple(R.Object) = lv
+            gv = (_io,)
             R.output(gv)
         return gv
 
@@ -693,8 +691,7 @@ def inplace_take(
     def _initialize_effect() -> R.Tuple(R.Object):
         with R.dataflow():
             _io: R.Object = R.null_value()
-            lv: R.Tuple(R.Object) = (_io,)
-            gv: R.Tuple(R.Object) = lv
+            gv = (_io,)
             R.output(gv)
         return gv
 
@@ -711,13 +708,12 @@ def test(
         R.func_attr({"num_input": 4})
         cls = Expected
         with R.dataflow():
-            lv1 = R.call_tir(
+            gv1 = R.call_tir(
                 cls.inplace_take,
                 (embedding_table, input_ids, embedding_dst),
                 out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
                 tir_vars=R.shape([offset_1]),
             )
-            gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
             R.output(gv1)
         return gv1
 
@@ -766,8 +762,7 @@ def test(A: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl
         R.func_attr({"num_input": 1})
         cls = Expected
         with R.dataflow():
-            lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
-            gv: R.Tensor((16, 16), dtype="float32") = lv
+            gv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
             R.output(gv)
         return gv
 
@@ -794,8 +789,7 @@ class Expected:
     def _initialize_effect() -> R.Tuple(R.Object):
         with R.dataflow():
             _io: R.Object = R.null_value()
-            lv: R.Tuple(R.Object) = (_io,)
-            gv: R.Tuple(R.Object) = lv
+            gv = (_io,)
             R.output(gv)
         return gv
 
@@ -845,7 +839,6 @@ def test(self):
 
 @tvm.testing.requires_gpu
 def test_multinomial_from_uniform():
-
     prob_shape = (3, 5)
     sample_shape = (6, 1)
 
@@ -882,8 +875,7 @@ def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
     def _initialize_effect() -> R.Tuple(R.Object):
         with R.dataflow():
             _io: R.Object = R.null_value()
-            lv: R.Tuple(R.Object) = (_io,)
-            gv: R.Tuple(R.Object) = lv
+            gv = (_io,)
             R.output(gv)
         return gv
 
@@ -1009,8 +1001,7 @@ def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
     def _initialize_effect() -> R.Tuple(R.Object):
         with R.dataflow():
             _io: R.Object = R.null_value()
-            lv: R.Tuple(R.Object) = (_io,)
-            gv: R.Tuple(R.Object) = lv
+            gv = (_io,)
             R.output(gv)
         return gv
 
@@ -1124,8 +1115,7 @@ def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.h
     def _initialize_effect() -> R.Tuple(R.Object):
         with R.dataflow():
             _io: R.Object = R.null_value()
-            lv: R.Tuple(R.Object) = (_io,)
-            gv: R.Tuple(R.Object) = lv
+            gv = (_io,)
             R.output(gv)
         return gv

tests/python/relax/test_frontend_nn_packing.py

Lines changed: 1 addition & 2 deletions
@@ -59,8 +59,7 @@ def forward(
             matmul = R.matmul(x, matmul_1_weight)
             matmul_2_weight = R.permute_dims(linear_2_weight)
             matmul1 = R.matmul(x, matmul_2_weight)
-            add = R.add(matmul, matmul1)
-            gv = add
+            gv = R.add(matmul, matmul1)
             R.output(gv)
         return gv
