diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py
index 2eaae1335855..63b4424524eb 100644
--- a/python/tvm/contrib/msc/core/frontend/translate.py
+++ b/python/tvm/contrib/msc/core/frontend/translate.py
@@ -119,6 +119,7 @@ def from_relax(
     )(mod)
     patterns = get_patterns_with_prefix("msc.")
     passes = [
+        tvm.relax.transform.ExpandTupleArguments(),
         msc_transform.SetExprName(),
         msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
         tvm.relax.transform.FuseOpsByPattern(
@@ -310,6 +311,7 @@ def byoc_partition(
     def _partition_mod(mod, as_msc=True):
         patterns = get_patterns_with_prefix(target)
         passes = [
+            tvm.relax.transform.ExpandTupleArguments(),
             msc_transform.SetExprName(),
             msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
             tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=not as_msc),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 35131d324076..6d01283d3ecd 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -526,6 +526,22 @@ def _einsum(self, node: fx.node.Node) -> relax.Var:
             return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0]))
         return self.block_builder.emit(relax.op.einsum(args[1:], args[0]))
 
+    def _unbind(self, node: fx.node.Node) -> relax.Var:
+        if len(node.args) == 2:
+            assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int"
+            dim = node.args[1]
+        elif "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        else:
+            dim = 0
+        x = self.env[node.args[0]]
+        selections = self.shape_of(x)[dim].value
+        n_section = list(range(1, selections + 1))
+        ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim))
+        for i in range(selections):
+            ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
+        return self.block_builder.emit(relax.Tuple(ret))
+
     ########## Manipulation ##########
 
     def _cat(self, node: fx.node.Node) -> relax.Var:
@@ -535,7 +551,13 @@ def _cat(self, node: fx.node.Node) -> relax.Var:
 
     def _expand(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
-        return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1:]))
+        broadcast_shape, in_shape = [], self.shape_of(args[0])
+        for idx, i in enumerate(args[1:]):
+            if isinstance(i, int) and i == -1:
+                broadcast_shape.append(in_shape[idx])
+            else:
+                broadcast_shape.append(i)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
 
     def _flatten(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -580,7 +602,13 @@ def _split(self, node: fx.node.Node) -> relax.Var:
             dim = node.kwargs["dim"]
         else:
             dim = 0
-        n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
+        if isinstance(split_size, (list, tuple)):
+            n_section = []
+            for s in split_size[:-1]:
+                cum_sum = 0 if not n_section else n_section[-1]
+                n_section.append(s + cum_sum)
+        else:
+            n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
         return self.block_builder.emit(relax.op.split(x, n_section, dim))
 
     def _chunk(self, node: fx.node.Node) -> relax.Var:
@@ -1501,6 +1529,7 @@ def create_convert_map(self):
             "cross_entropy": self._cross_entropy,
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
             "einsum": self._einsum,
+            "unbind": self._unbind,
         }
 
     def update_convert_map(self, custom_convert_map: dict):
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 315d6813ea99..069ffff53bd7 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -1345,11 +1345,15 @@ def forward(self, x_1, x_2, x_3):
 def test_split():
     """test graph builder for split"""
 
-    class Split(Module):
+    class Split1(Module):
         def forward(self, data):
             return torch.split(data, 1, dim=1)
 
-    expected = {
+    class Split2(Module):
+        def forward(self, data):
+            return torch.split(data, [1, 2], dim=1)
+
+    expected1 = {
         "inputs": [
             {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
         ],
@@ -1361,8 +1365,43 @@ def forward(self, data):
         "nodes": {"total": 2, "input": 1, "split": 1},
     }
 
+    expected2 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
+        ],
+        "outputs": [
+            {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
+            {"name": "split_1", "shape": [1, 2, 10, 10], "dtype": "float32", "layout": "ABCD"},
+        ],
+        "nodes": {"total": 2, "input": 1, "split": 1},
+    }
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Split1(), input_info, expected1)
+    verify_model(Split2(), input_info, expected2)
+
+
+def test_unbind():
+    """test graph builder for unbind"""
+
+    class Unbind(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    expected = {
+        "inputs": [
+            {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
+        ],
+        "outputs": [
+            {"name": "tuple_0", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"},
+            {"name": "tuple_1", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"},
+            {"name": "tuple_2", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"},
+        ],
+        "nodes": {"total": 9, "input": 1, "split": 1, "get_item": 3, "squeeze": 3, "tuple": 1},
+    }
+
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Split(), input_info, expected)
+    verify_model(Unbind(), input_info, expected)
 
 
 def test_cumsum():
@@ -1547,10 +1586,14 @@ def forward(self, x):
 def test_expand():
     """test graph builder for expand"""
 
-    class Expand(Module):
+    class Expand1(Module):
         def forward(self, x):
             return x.expand(4, 2, 3, 4)
 
+    class Expand2(Module):
+        def forward(self, x):
+            return x.expand(4, -1, -1, 4)
+
     expected = {
         "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}],
         "outputs": [
@@ -1560,7 +1603,8 @@ def forward(self, x):
     }
 
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Expand(), input_info, expected)
+    verify_model(Expand1(), input_info, expected)
+    verify_model(Expand2(), input_info, expected)
 
 
 def test_reduce():
diff --git a/tests/python/contrib/test_msc/test_pipeline.py b/tests/python/contrib/test_msc/test_pipeline.py
index c7a26bf96efb..149041959416 100644
--- a/tests/python/contrib/test_msc/test_pipeline.py
+++ b/tests/python/contrib/test_msc/test_pipeline.py
@@ -38,7 +38,7 @@ def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1
     path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static")
     return {
         "workspace": msc_utils.msc_dir(path),
-        "verbose": "info",
+        "verbose": "critical",
         "model_type": model_type,
         "inputs": inputs,
         "outputs": outputs,
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py
index 00975be85eca..e8b7149a68a2 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -67,7 +67,12 @@ def _run_relax(relax_mod):
 
     orig_output = _run_relax(orig_mod)
     rt_output = _run_relax(rt_mod)
-    tvm.testing.assert_allclose(orig_output, rt_output)
+    if not isinstance(orig_output, (list, tuple)):
+        orig_output = [orig_output]
+    if not isinstance(rt_output, (list, tuple)):
+        rt_output = [rt_output]
+    for o_out, r_out in zip(orig_output, rt_output):
+        tvm.testing.assert_allclose(o_out, r_out)
 
 
 def test_conv1d():
@@ -750,12 +755,33 @@ def forward(self, x_1, x_2, x_3):
 def test_split():
     """test relax translator for split"""
 
-    class Split(Module):
+    class Split1(Module):
         def forward(self, data):
             return torch.split(data, 1, dim=1)
 
+    class Split2(Module):
+        def forward(self, data):
+            return torch.split(data, [1, 2], dim=1)
+
     input_info = [([1, 3, 10, 10], "float32")]
-    _verify_model(Split(), input_info)
+    _verify_model(Split1(), input_info)
+    _verify_model(Split2(), input_info)
+
+
+def test_unbind():
+    """test relax translator for unbind"""
+
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    input_info = [([3, 3, 10, 10], "float32")]
+    _verify_model(Unbind1(), input_info)
+    _verify_model(Unbind2(), input_info)
 
 
 def test_cumsum():
@@ -874,12 +900,17 @@ def forward(self, x):
 def test_expand():
     """test relax translator for expand"""
 
-    class Expand(Module):
+    class Expand1(Module):
         def forward(self, x):
             return x.expand(4, 2, 3, 4)
 
+    class Expand2(Module):
+        def forward(self, x):
+            return x.expand(4, -1, -1, 4)
+
     input_info = [([1, 2, 3, 4], "float32")]
-    _verify_model(Expand(), input_info)
+    _verify_model(Expand1(), input_info)
+    _verify_model(Expand2(), input_info)
 
 
 def test_reduce():
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py
index 6c47b8b39545..3790da3f3d8e 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -731,12 +731,33 @@ def forward(self, x_1, x_2, x_3):
 def test_split():
     """test relay to relax for split"""
 
-    class Split(Module):
+    class Split1(Module):
         def forward(self, data):
             return torch.split(data, 1, dim=1)
 
+    class Split2(Module):
+        def forward(self, data):
+            return torch.split(data, [1, 2], dim=1)
+
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Split(), input_info, build_target="llvm")
+    verify_model(Split1(), input_info, build_target="llvm")
+    verify_model(Split2(), input_info, build_target="llvm")
+
+
+def test_unbind():
+    """test relay to relax for unbind"""
+
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    input_info = [([3, 3, 10, 10], "float32")]
+    verify_model(Unbind1(), input_info, build_target="llvm")
+    verify_model(Unbind2(), input_info, build_target="llvm")
 
 
 def test_cumsum():
@@ -859,12 +880,17 @@ def forward(self, x):
 def test_expand():
     """test relay to relax for expand"""
 
-    class Expand(Module):
+    class Expand1(Module):
         def forward(self, x):
             return x.expand(4, 2, 3, 4)
 
+    class Expand2(Module):
+        def forward(self, x):
+            return x.expand(4, -1, -1, 4)
+
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Expand(), input_info, build_target="llvm")
+    verify_model(Expand1(), input_info, build_target="llvm")
+    verify_model(Expand2(), input_info, build_target="llvm")
 
 
 def test_reduce():
diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py
index 81104e6fe0f2..74c25ceacfe8 100644
--- a/tests/python/contrib/test_msc/test_translate_tensorrt.py
+++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py
@@ -673,12 +673,34 @@ def forward(self, x_1, x_2, x_3):
 def test_split():
     """test tensorrt translator for split"""
 
-    class Split(Module):
+    class Split1(Module):
         def forward(self, data):
             return torch.split(data, 1, dim=1)
 
+    class Split2(Module):
+        def forward(self, data):
+            return torch.split(data, [1, 2], dim=1)
+
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Split(), input_info)
+    verify_model(Split1(), input_info)
+    verify_model(Split2(), input_info)
+
+
+@requires_tensorrt
+def test_unbind():
+    """test tensorrt to relax for unbind"""
+
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    input_info = [([3, 3, 10, 10], "float32")]
+    verify_model(Unbind1(), input_info)
+    verify_model(Unbind2(), input_info)
 
 
 @requires_tensorrt
@@ -697,13 +719,19 @@ def forward(self, data):
 def test_expand():
     """test tensorrt translator for expand"""
 
-    class Expand(Module):
+    class Expand1(Module):
         def forward(self, x):
             x = x + 1.0
             return x.expand(4, 2, 3, 4)
 
+    class Expand2(Module):
+        def forward(self, x):
+            x = x + 1.0
+            return x.expand(4, -1, -1, 4)
+
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Expand(), input_info)
+    verify_model(Expand1(), input_info)
+    verify_model(Expand2(), input_info)
 
 
 @requires_tensorrt
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py
index 81c6031ce17a..60dcbb293a51 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -728,13 +728,35 @@ def forward(self, x_1, x_2, x_3):
 def test_split():
     """test torch translator for split"""
 
-    class Split(Module):
+    class Split1(Module):
         def forward(self, data):
             return torch.split(data, 1, dim=1)
 
+    class Split2(Module):
+        def forward(self, data):
+            return torch.split(data, [1, 2], dim=1)
+
     input_info = [([1, 3, 10, 10], "float32")]
     for via_relax in [True, False]:
-        verify_model(Split(), input_info, via_relax)
+        verify_model(Split1(), input_info, via_relax)
+        verify_model(Split2(), input_info, via_relax)
+
+
+def test_unbind():
+    """test torch translator for unbind"""
+
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    input_info = [([3, 3, 10, 10], "float32")]
+    for via_relax in [True, False]:
+        verify_model(Unbind1(), input_info, via_relax)
+        verify_model(Unbind2(), input_info, via_relax)
 
 
 def test_cumsum():
@@ -835,13 +857,18 @@ def forward(self, x):
 def test_expand():
     """test torch translator for expand"""
 
-    class Expand(Module):
+    class Expand1(Module):
         def forward(self, x):
             return x.expand(4, 2, 3, 4)
 
+    class Expand2(Module):
+        def forward(self, x):
+            return x.expand(4, -1, -1, 4)
+
     input_info = [([1, 2, 3, 4], "float32")]
     for via_relax in [True, False]:
-        verify_model(Expand(), input_info, via_relax)
+        verify_model(Expand1(), input_info, via_relax)
+        verify_model(Expand2(), input_info, via_relax)
 
 
 def test_reduce():
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 6be3e7b23e9d..5398fe342073 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2714,10 +2714,14 @@ def main(
 def test_split():
     input_info = [([1, 3, 10, 10], "float32")]
 
-    class Split(Module):
+    class Split1(Module):
         def forward(self, input):
             return torch.split(input, 1, dim=1)
 
+    class Split2(Module):
+        def forward(self, input):
+            return torch.split(input, [1, 2], dim=1)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -2743,7 +2747,118 @@ def main(
                 R.output(gv)
             return gv
 
-    verify_model(Split(), input_info, {}, expected1)
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 2, 10, 10), dtype="float32")
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 2, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1], axis=1)
+                gv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 2, 10, 10), dtype="float32"),
+                ) = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Split1(), input_info, {}, expected1)
+    verify_model(Split2(), input_info, {}, expected2)
+
+
+def test_unbind():
+    input_info = [([3, 3, 10, 10], "float32")]
+
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((0, 3, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0])
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = lv7
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 0, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
+                lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1])
+                lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = lv7
+                R.output(gv)
+            return gv
+
+    verify_model(Unbind1(), input_info, {}, expected1)
+    verify_model(Unbind2(), input_info, {}, expected2)
 
 
 def test_cumsum():
@@ -2970,10 +3085,14 @@ def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="
 def test_expand():
     input_info = [([1, 2, 3, 4], "float32")]
 
-    class Expand(Module):
+    class Expand1(Module):
         def forward(self, x):
             return x.expand(4, 2, 3, 4)
 
+    class Expand2(Module):
+        def forward(self, x):
+            return x.expand(4, -1, -1, 4)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -2987,7 +3106,8 @@ def main(
                 R.output(gv)
             return gv
 
-    verify_model(Expand(), input_info, {}, expected1)
+    verify_model(Expand1(), input_info, {}, expected1)
+    verify_model(Expand2(), input_info, {}, expected1)
 
 
 def test_reduce():
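Note on the frontend change above: the new _unbind converter in fx_translator.py lowers torch.unbind into a relax.op.split at every index along the chosen dimension, a relax.op.squeeze of each slice, and a relax.Tuple packing the slices, which is the pattern the expected TVMScript modules in test_frontend_from_fx.py spell out. A minimal sketch of exercising this path through the public FX frontend follows; the module name and input shape are illustrative, not part of this patch.

# Illustrative only: trace a small module that calls torch.unbind and import it
# with the Relax FX frontend, which now maps unbind -> split + squeeze + tuple.
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx


class UnbindDemo(torch.nn.Module):  # hypothetical example module, not from the patch
    def forward(self, data):
        return torch.unbind(data, dim=1)


graph_model = fx.symbolic_trace(UnbindDemo())
# Shape and dtype mirror the new tests; dim=1 has extent 3, so the importer
# should emit R.split(..., indices_or_sections=[1, 2, 3], axis=1) plus squeezes.
mod = from_fx(graph_model, [([1, 3, 10, 10], "float32")])
mod.show()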
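Likewise, _split now accepts a Python list of section sizes by converting it into the cumulative split indices that relax.op.split expects, and _expand resolves -1 entries in the target shape against the input's own extents. A small standalone sketch of the size-to-index conversion, in plain Python with a hypothetical helper name, assuming the same cumulative-sum logic as the patch:

# Illustrative helper mirroring the cumulative-sum loop added to _split:
# torch.split section sizes are turned into the indices relax.op.split expects.
def sizes_to_indices(split_sizes):
    indices = []
    for size in split_sizes[:-1]:  # the last section is implied by the remainder
        prev = indices[-1] if indices else 0
        indices.append(prev + size)
    return indices


# torch.split(data, [1, 2], dim=1) on an axis of extent 3 -> split indices [1]
assert sizes_to_indices([1, 2]) == [1]
assert sizes_to_indices([2, 3, 5]) == [2, 5]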