From e85d05ffae3d9a05c5af327ebf644a9ea53d26f2 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 24 Mar 2024 14:09:13 -0400 Subject: [PATCH] Revert "[SLM] Allow modules to define pre-processing of weights (#16757)" This reverts commit 1cccc3b5d65cae743a2becb7e256c05897af29ca. --- python/tvm/relax/frontend/nn/core.py | 17 +- python/tvm/relax/frontend/nn/exporter.py | 40 +- .../python/relax/test_frontend_nn_exporter.py | 443 ------------------ .../relax/test_frontend_nn_extern_module.py | 10 +- .../python/relax/test_frontend_nn_modules.py | 3 +- tests/python/relax/test_frontend_nn_op.py | 27 +- .../python/relax/test_frontend_nn_packing.py | 3 +- .../relax/test_frontend_nn_subroutines.py | 13 +- 8 files changed, 58 insertions(+), 498 deletions(-) delete mode 100644 tests/python/relax/test_frontend_nn_exporter.py diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 820acd235d8c..b7b3f411ed41 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -591,22 +591,7 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]: The computed result. """ if not isinstance(expr, rx.DataflowVar): - block_builder = BlockBuilder.current() - if block_builder is None: - # Normalize to make sure we have valid StructInfo, but - # wait until we are actually building the function to - # flatten nested expressions. - # - # TODO(Lunderberg): Make this easier to call. Infering - # struct info for a nested expression should be doable in - # a free function, without requiring an active - # BlockBuilder and an active FunctionFrame. - builder = BlockBuilder() - with builder.function("dummy_scope", params=[]): - expr = builder.normalize(expr) - builder.emit_func_output([]) - else: - expr = BlockBuilder.current().emit(expr, name) + expr = BlockBuilder.current().emit(expr, name) if isinstance(expr.struct_info_, TensorStructInfo): return Tensor(_expr=expr) if isinstance(expr.struct_info_, TupleStructInfo): diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py index 525d689f4995..1a7dcd6a648b 100644 --- a/python/tvm/relax/frontend/nn/exporter.py +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -111,8 +111,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: return result # pylint: enable=protected-access - - params = _params() + params = None effects = _effects() ext_mods = self.extern_mods with self: @@ -122,6 +121,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: outputs = _emit_effect_init(self.builder, effects) self.builder.emit_func_output(outputs, params=[]) for method_name, method_spec in zip(spec.method_names, spec.method_specs): + params = _params() # Re-initialize so symbolic shapes not shared across methods len_args = len(method_spec.arg_specs) len_effects = { "packed": 1, @@ -135,18 +135,9 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: with self.builder.dataflow(): outputs, inputs = _emit_method(self.builder, method_spec, params, effects) self.builder.emit_func_output(outputs, inputs) - - # TODO(Lunderberg): Make a `ir.transform.ConvertSSA`, - # similar to the existing `tir.transform.ConvertSSA`, - # that converts an entire module to SSA, including TIR - # variable definitions used in either TIR or Relax. - mod = self.builder.get() - mod[method_name] = rx.utils.copy_with_new_vars(mod[method_name]) - mod = self.builder.finalize() assert rx.analysis.well_formed(mod) - mod = rx.transform.CanonicalizeBindings()(mod) return mod, params, ext_mods @@ -170,6 +161,8 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many- effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]], ): # pylint: disable=protected-access + # symbolic shape's name mapping to its tir.Var for reuse + str2var_params: typing.Dict[str, tir.Var] = {} def _unwrap_ret(expr: typing.Any) -> typing.Any: if isinstance(expr, (core.Tensor, core.Object)): @@ -183,26 +176,35 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any: def _convert_input(arg): if isinstance(arg, tir.Var): return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg])) - elif isinstance(arg, (core.Tensor, core.Object)): + if isinstance(arg, (core.Tensor, core.Object)): return arg._expr # pylint: disable=protected-access - elif isinstance(arg, _spec.Tuple): + if isinstance(arg, _spec.Tuple): return rx.Var( arg.name, struct_info=TupleStructInfo( [_convert_input(arg_i).struct_info for arg_i in arg.elements] ), ) - elif isinstance(arg, rx.Expr): - return arg - else: - raise TypeError(f"Unsupported input type: {type(arg)}") + raise TypeError(f"Unsupported input type: {type(arg)}") def _params(mode: str) -> typing.List[rx.Var]: inputs: typing.List[rx.Var] = [] - for name, param in params: - inputs.append(param._expr) + def _get_var(shape_var: tir.Var) -> tir.Var: + name = shape_var.name + if name in str2var_params: + return str2var_params[name] + var = tir.Var(name, "int64") + str2var_params[name] = var + return var + for name, param in params: + # Make sure the a symbolic shape is not re-registered (same as _method_spec_to_inputs) + # e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens` + new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape] + var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr + inputs.append(var) + param._expr = var if mode == "none": return [] if mode == "plain": diff --git a/tests/python/relax/test_frontend_nn_exporter.py b/tests/python/relax/test_frontend_nn_exporter.py deleted file mode 100644 index de8900238bb6..000000000000 --- a/tests/python/relax/test_frontend_nn_exporter.py +++ /dev/null @@ -1,443 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -import tvm -import tvm.testing - -from tvm import relax, tir -from tvm.ir import assert_structural_equal -from tvm.relax.frontend import nn -from tvm.script import ir as I, relax as R, tir as T - - -def test_simple(): - """A module may be exported from nn.Module to Relax""" - - slm_mod = nn.modules.ReLU() - exported_mod, _ = slm_mod.export_tvm( - spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}}, - debug=False, - ) - - @I.ir_module - class Expected: - @R.function - def forward(x: R.Tensor([3, 3], dtype="float32")): - R.func_attr({"num_input": 1}) - with R.dataflow(): - relu = R.nn.relu(x) - R.output(relu) - return relu - - assert_structural_equal(exported_mod, Expected) - - -def test_custom_module(): - """A module may be exported from nn.Module to Relax""" - - class Before(nn.Module): - def forward(self, x: R.Tensor): - return nn.op.relu(x) - - slm_mod = Before() - exported_mod, _ = slm_mod.export_tvm( - spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}}, - debug=False, - ) - - @I.ir_module - class Expected: - @R.function - def forward(x: R.Tensor([3, 3], dtype="float32")): - R.func_attr({"num_input": 1}) - with R.dataflow(): - relu = R.nn.relu(x) - R.output(relu) - return relu - - assert_structural_equal(exported_mod, Expected) - - -def test_debug_effect(): - """Passing debug=True provides an argument for IO effect""" - - slm_mod = nn.modules.ReLU() - exported_mod, _ = slm_mod.export_tvm( - spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}}, - debug=True, - ) - - @I.ir_module - class Expected: - @R.function - def forward( - x: R.Tensor([3, 3], dtype="float32"), - _io: R.Object, - ): - R.func_attr({"num_input": 2}) - with R.dataflow(): - relu = R.nn.relu(x) - output = relu, (_io,) - R.output(output) - return output - - @R.function - def _initialize_effect(): - with R.dataflow(): - _io = R.null_value() - output = (_io,) - R.output(output) - return output - - assert_structural_equal(exported_mod, Expected) - - -def test_dynamic_shape(): - """An argument may have a dynamic shape""" - - slm_mod = nn.modules.ReLU() - exported_mod, _ = slm_mod.export_tvm( - spec={"forward": {"x": nn.spec.Tensor([tir.Var("batch_size", "int64"), 8], "float32")}}, - debug=False, - ) - - @I.ir_module - class Expected: - @R.function - def forward(x: R.Tensor(["batch_size", 8], dtype="float32")): - R.func_attr({"num_input": 1}) - with R.dataflow(): - relu = R.nn.relu(x) - R.output(relu) - return relu - - assert_structural_equal(exported_mod, Expected) - - -def test_dynamic_shape_in_multiple_functions(): - """A dynamic shape may be used in multiple functions""" - - class Before(nn.Module): - def forward_relu(self, x: nn.Tensor): - return nn.relu(x) - - def forward_silu(self, x: nn.Tensor): - return nn.silu(x) - - slm_mod = Before() - exported_mod, _ = slm_mod.export_tvm( - spec={ - "forward_relu": {"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 8), "float32")}, - "forward_silu": {"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 8), "float32")}, - }, - debug=False, - ) - - @I.ir_module - class Expected: - @R.function - def forward_relu(x: R.Tensor(["batch_size", 8], dtype="float32")): - R.func_attr({"num_input": 1}) - with R.dataflow(): - relu = R.nn.relu(x) - R.output(relu) - return relu - - @R.function - def forward_silu(x: R.Tensor(["batch_size", 8], dtype="float32")): - R.func_attr({"num_input": 1}) - with R.dataflow(): - silu = R.nn.silu(x) - R.output(silu) - return silu - - assert_structural_equal(exported_mod, Expected) - - -def test_export_nested_module(): - """nn.Module instances may contain other nn.Module - - When exporting to a Relax IRModule, all `nn.Parameter` instances - within the `nn.Module` become Relax function parameters. - """ - - class LlamaMLP(nn.Module): - def __init__(self, hidden_size: int, intermediate_size: int): - super().__init__() - self.gate_proj = nn.Linear( - in_features=hidden_size, - out_features=intermediate_size, - dtype="float16", - bias=False, - ) - self.up_proj = nn.Linear( - in_features=hidden_size, - out_features=intermediate_size, - dtype="float16", - bias=False, - ) - self.down_proj = nn.Linear( - intermediate_size, - hidden_size, - dtype="float16", - bias=False, - ) - - def forward(self, x: nn.Tensor): - gate = self.gate_proj(x) - up = self.up_proj(x) - return self.down_proj(nn.op.silu(gate) * up) - - hidden_size = 4096 - intermediate_size = 11008 - slm_mod = LlamaMLP(hidden_size=hidden_size, intermediate_size=intermediate_size) - exported_mod, _ = slm_mod.export_tvm( - spec={ - "forward": { - "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16") - }, - }, - debug=False, - ) - - @I.ir_module - class Expected: - @R.function - def forward( - x: R.Tensor(["batch_size", hidden_size], "float16"), - gate_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"), - up_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"), - down_proj_weights: R.Tensor([hidden_size, intermediate_size], "float16"), - ): - R.func_attr({"num_input": 1}) - batch_size = T.int64() - with R.dataflow(): - gate: R.Tensor([batch_size, intermediate_size]) = R.matmul( - x, R.permute_dims(gate_proj_weights) - ) - up: R.Tensor([batch_size, intermediate_size]) = R.matmul( - x, R.permute_dims(up_proj_weights) - ) - down: R.Tensor([batch_size, hidden_size]) = R.matmul( - R.nn.silu(gate) * up, R.permute_dims(down_proj_weights) - ) - R.output(down) - return down - - assert_structural_equal(exported_mod, Expected) - - -def test_generate_parameters(): - """Weights may be expressions in terms of other parameters - - Optimizations often require preprocessing of the model weights. - - 1. Declare the `nn.Module` members that contain the original model - weights. These are used to define the parameter names when - reading from a Pytorch or Safetensors file. - - 2. Declare the `nn.Module` members, with the `weight` field - in terms of the un-optimized weights. These `nn.Module` - do not generate any parameters in the Relax function. - - 3. Define the `forward` function in terms of the `nn.Module` - members for the updated weight tensors. - - The exported Relax function accepts the original model parameters, - computes the pre-processed weights, and then performs computations - using the pre-processed weights. - - In this example, the `LiftTransformParams` transform is applied - immediately, splitting the Relax function into a pre-processing - step and an execution step. In practice, this transform would be - applied much later in an optimization pipeline, to allow optimized - compute kernels to be recognized. For example, in some cases - `R.matmul(x, R.permute_dims(weight))` may be computed more - efficiently than `R.matmul(x, weight_transpose)`. For this - reason, we do *not* apply `LiftTransformParams` as part of the - export from `nn.Module` to Relax. - - """ - - class LlamaMLP(nn.Module): - def __init__(self, hidden_size: int, intermediate_size: int): - super().__init__() - # The nn.Linear for the original parameters are present in - # the model definition, and are still found when - # collecting a function's parameters. - self.gate_proj = nn.Linear( - in_features=hidden_size, - out_features=intermediate_size, - dtype="float16", - bias=False, - ) - self.up_proj = nn.Linear( - in_features=hidden_size, - out_features=intermediate_size, - dtype="float16", - bias=False, - ) - self.down_proj = nn.Linear( - intermediate_size, - hidden_size, - dtype="float16", - bias=False, - ) - - # At runtime, we'd like to have a single concatenated - # tensor containing both the gate and up projection - # weights. We also want to use it in the `forward` - # function as if it owned its own weights. - self.gate_up_proj = nn.Linear( - in_features=hidden_size, - out_features=intermediate_size, - dtype="float16", - bias=False, - ) - - # The weight tensor of `gate_up_proj` can be overwritten - # in terms of the original `gate_proj` and `up_proj` - # tensors. - self.gate_up_proj.weight = nn.op.concat( - [self.gate_proj.weight, self.up_proj.weight], dim=0, name="gate_up_proj_weights" - ) - - def forward(self, x: nn.Tensor): - # Even though the `gate_up_proj` weights are defined as an - # expression rather than a `nn.Parameter`, the `forward` - # function does not require any special handling for it. - concat_gate_up = self.gate_up_proj(x) - gate, up = nn.op.split(concat_gate_up, 2, axis=-1) - return self.down_proj(nn.op.silu(gate) * up) - - hidden_size = 4096 - intermediate_size = 11008 - slm_mod = LlamaMLP(hidden_size=hidden_size, intermediate_size=intermediate_size) - exported_mod, _ = slm_mod.export_tvm( - spec={ - "forward": { - "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16") - }, - }, - debug=False, - ) - - @I.ir_module - class Expected: - @R.function - def forward( - x: R.Tensor(["batch_size", hidden_size], "float16"), - # The function's parameters are defined by the - # `nn.Parameter` instances, and still reference the - # original `gate_proj` and `up_proj` weights. This - # maintains compatibility with named model weights in a - # Pytorch or Safetensors file. - gate_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"), - up_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"), - down_proj_weights: R.Tensor([hidden_size, intermediate_size], "float16"), - ): - R.func_attr({"num_input": 1}) - batch_size = T.int64() - with R.dataflow(): - # At this stage of compilation, the concatenation is - # written within the body of the function. This will - # later be extracted into a pre-processing step using - # `relax.transform.LiftTransformParams`. - gate_up_proj_weights: R.Tensor( - [intermediate_size * 2, hidden_size], "float16" - ) = R.concat([gate_proj_weights, up_proj_weights], axis=0) - gate_up: R.Tensor([batch_size, intermediate_size * 2], "float16") = R.matmul( - x, R.permute_dims(gate_up_proj_weights) - ) - gate_up_split = R.split(gate_up, 2, axis=-1) - gate = gate_up_split[0] - up = gate_up_split[1] - down: R.Tensor([batch_size, hidden_size], "float16") = R.matmul( - R.nn.silu(gate) * up, R.permute_dims(down_proj_weights) - ) - R.output(down) - return down - - assert_structural_equal(exported_mod, Expected) - - @I.ir_module - class ExpectedAfterLift: - @R.function - def forward( - x: R.Tensor(["batch_size", hidden_size], "float16"), - # After `relax.transform.LiftTransformParams`, the - # `gate_proj` and `up_proj` weights have been concatenated - # together. - gate_up_proj_weights_transpose: R.Tensor( - [hidden_size, intermediate_size * 2], "float16" - ), - down_proj_weights_transpose: R.Tensor([intermediate_size, hidden_size], "float16"), - ): - R.func_attr({"num_input": 1}) - batch_size = T.int64() - with R.dataflow(): - gate_up: R.Tensor([batch_size, intermediate_size * 2], "float16") = R.matmul( - x, gate_up_proj_weights_transpose - ) - gate_up_split = R.split(gate_up, 2, axis=-1) - gate = gate_up_split[0] - up = gate_up_split[1] - down: R.Tensor([batch_size, hidden_size], "float16") = R.matmul( - R.nn.silu(gate) * up, down_proj_weights_transpose - ) - R.output(down) - return down - - @R.function - def transform_params( - model_params: R.Tuple( - R.Tensor([intermediate_size, hidden_size], "float16"), - R.Tensor([intermediate_size, hidden_size], "float16"), - R.Tensor([hidden_size, intermediate_size], "float16"), - ) - ): - R.func_attr({"num_input": 0}) - with R.dataflow(): - gate_proj_weights: R.Tensor( - [intermediate_size, hidden_size], "float16" - ) = model_params[0] - up_proj_weights: R.Tensor( - [intermediate_size, hidden_size], "float16" - ) = model_params[1] - gate_up_proj_weights: R.Tensor( - [intermediate_size * 2, hidden_size], "float16" - ) = R.concat([gate_proj_weights, up_proj_weights], axis=0) - gate_up_proj_weights_transpose: R.Tensor( - [hidden_size, intermediate_size * 2], "float16" - ) = R.permute_dims(gate_up_proj_weights) - down_proj_weights: R.Tensor( - [hidden_size, intermediate_size], "float16" - ) = model_params[2] - down_proj_weights_transpose: R.Tensor( - [intermediate_size, hidden_size], "float16" - ) = R.permute_dims(down_proj_weights) - output = (gate_up_proj_weights_transpose, down_proj_weights_transpose) - R.output(output) - return output - - lifted_mod = relax.transform.LiftTransformParams(shared_transform=True)(exported_mod) - assert_structural_equal(lifted_mod, ExpectedAfterLift) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py index 6ca774242274..6eaf1fbfc805 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -94,8 +94,9 @@ def scalar_add( ext_scalar_add = R.call_dps_packed( "ext_scalar_add", (a, b), out_sinfo=R.Tensor((), dtype="float32") ) - R.output(ext_scalar_add) - return ext_scalar_add + gv: R.Tensor((), dtype="float32") = ext_scalar_add + R.output(gv) + return gv @R.function def test_sym( @@ -109,8 +110,9 @@ def test_sym( ext_test_sym = R.call_dps_packed( "ext_test_sym", (a, b), out_sinfo=R.Tensor((x, y, z, 9), dtype="float32") ) - R.output(ext_test_sym) - return ext_test_sym + gv1: R.Tensor((x, y, z, 9), dtype="float32") = ext_test_sym + R.output(gv1) + return gv1 tvm.ir.assert_structural_equal(ExpectedModule, mod) diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index 45128749e23d..5ddc10505591 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -493,7 +493,8 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object): R.prim_value(0), sinfo_args=[R.Object()], ) - gv = _io, cache + lv1 = _io, cache + gv = lv1 R.output(gv) return gv diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 68f86bba50e8..7d78e47c945b 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -538,7 +538,8 @@ def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer( def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - gv = (_io,) + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv R.output(gv) return gv @@ -610,7 +611,8 @@ def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k: def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - gv = (_io,) + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv R.output(gv) return gv @@ -697,7 +699,8 @@ def inplace_take( def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - gv = (_io,) + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv R.output(gv) return gv @@ -714,12 +717,13 @@ def test( R.func_attr({"num_input": 4}) cls = Expected with R.dataflow(): - gv1 = R.call_tir( + lv1 = R.call_tir( cls.inplace_take, (embedding_table, input_ids, embedding_dst), out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype), tir_vars=R.shape([offset_1]), ) + gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1 R.output(gv1) return gv1 @@ -768,7 +772,8 @@ def test(A: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl R.func_attr({"num_input": 1}) cls = Expected with R.dataflow(): - gv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32")) + lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32")) + gv: R.Tensor((16, 16), dtype="float32") = lv R.output(gv) return gv @@ -795,7 +800,8 @@ class Expected: def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - gv = (_io,) + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv R.output(gv) return gv @@ -882,7 +888,8 @@ def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - gv = (_io,) + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv R.output(gv) return gv @@ -1008,7 +1015,8 @@ def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - gv: R.Tuple(R.Object) = (_io,) + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv R.output(gv) return gv @@ -1122,7 +1130,8 @@ def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.h def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - gv: R.Tuple(R.Object) = (_io,) + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv R.output(gv) return gv diff --git a/tests/python/relax/test_frontend_nn_packing.py b/tests/python/relax/test_frontend_nn_packing.py index c2cc22c17d40..56b614a807b8 100644 --- a/tests/python/relax/test_frontend_nn_packing.py +++ b/tests/python/relax/test_frontend_nn_packing.py @@ -59,7 +59,8 @@ def forward( matmul = R.matmul(x, matmul_1_weight) matmul_2_weight = R.permute_dims(linear_2_weight) matmul1 = R.matmul(x, matmul_2_weight) - gv = R.add(matmul, matmul1) + add = R.add(matmul, matmul1) + gv = add R.output(gv) return gv diff --git a/tests/python/relax/test_frontend_nn_subroutines.py b/tests/python/relax/test_frontend_nn_subroutines.py index 32ae967916a8..6bbf57aeadde 100644 --- a/tests/python/relax/test_frontend_nn_subroutines.py +++ b/tests/python/relax/test_frontend_nn_subroutines.py @@ -61,7 +61,8 @@ def forward( def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - gv = (_io,) + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv R.output(gv) return gv @@ -74,8 +75,9 @@ def layer( with R.dataflow(): state = R.matmul(state, weights) state = Expected.activation(state) - R.output(state) - return state + dataflow_output = state + R.output(dataflow_output) + return dataflow_output @R.function(private=True) def activation( @@ -83,8 +85,9 @@ def activation( ) -> R.Tensor(("batch_size", 32), dtype="float32"): with R.dataflow(): state = R.nn.silu(state) - R.output(state) - return state + dataflow_output = state + R.output(dataflow_output) + return dataflow_output mod = Layer(64, 32) batch_size = tvm.tir.Var("batch_size", "int64")