From 9e302917e8dacdbd586ca9fe6c0160127cddca26 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sat, 8 Mar 2025 16:42:27 -0500
Subject: [PATCH 1/7] add capability for upsample with size OR scale_factor +
 unit test

---
 .../torch/exported_program_translator.py      |  30 ++++--
 .../relax/test_from_exported_to_cuda.py       | 101 ++++++++++++++++++
 2 files changed, 123 insertions(+), 8 deletions(-)
 create mode 100644 tests/python/relax/test_from_exported_to_cuda.py

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c8d9d12505c6..a2c14280bb5f 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -92,7 +92,7 @@ def _group_norm(self, node: fx.Node) -> relax.Var:
         )
 
     def _upsample_impl(
-        self, x: relax.Expr, size, align_corners: bool, scale_factor, method: str
+        self, x: relax.Expr, size, scale_factor, method: str, align_corners: bool,
     ) -> relax.Var:
         coord_trans = "align_corners" if align_corners else "half_pixel"
 
@@ -116,21 +116,35 @@ def _upsample_impl(
     def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
+        # TODO do we need the condition on size like we have in _upsample_nearest2d?
+
+        # TODO HL: I am doubtful that align_corners is args[2]. The pytorch
+        # arguments go size, scale_factor, mode, align_corners. See changes I
+        # made to _upsample_nearest2d. Need to test for _upsample_bilinear2d
         align_corners = (
             node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True)
         )
-        scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
-        return self._upsample_impl(x, size, align_corners, scale_factor, "linear")
+        scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1)  # TODO nonsense to default to 1! Understand why "None" doesn't work!
+        return self._upsample_impl(x, size=size, scale_factor=scale_factor,
+                                   method="linear", align_corners=align_corners)
 
     def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
         x = self.env[node.args[0]]
         size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
-        align_corners = (
-            node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True)
-        )
-        scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
-        return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")
+        if size:
+            scale_factor = None  # Can only define size or scale_factor, not both
+            align_corners = node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None)
+
+        else:
+            # TODO figure out why pytorch passes a list [scale_factor,scale_factor] instead of just an int. Passing first element for now
+            scale_factor = node.args[2][0] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1)  # TODO nonsense to default to 1! Understand why "None" doesn't work!
+            align_corners = node.args[3] if len(node.args) > 3 else node.kwargs.get("align_corners", None)  # TODO pytorch defaults to None
+
+        return self._upsample_impl(x, size=size, scale_factor=scale_factor,
+                                   method="nearest_neighbor",
+                                   align_corners=align_corners)
+
     ########## Manipulation ##########
 
     def _select(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
new file mode 100644
index 000000000000..611d5d302248
--- /dev/null
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relax
+import tvm.testing
+import numpy as np
+import torch
+from torch.export import export
+from tvm.relax.frontend.torch import from_exported_program
+from torch.nn import Softmax, Upsample
+
+
+def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
+    """
+    This util ensures that a torch module can successfully be exported to TVM
+    using torch.export and that the resulting IR program gives the same result
+    as PyTorch when run on CUDA.
+    """
+    torch_data = torch.from_numpy(raw_data)
+    example_args = (torch_data,)
+
+    with torch.no_grad():
+        exported_program = export(torch_module, example_args)
+        mod_from_torch = from_exported_program(
+            exported_program, keep_params_as_input=True
+        )
+
+    tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
+    target = tvm.target.Target.from_device(tvm.cuda())
+
+    ex = relax.build(tvm_mod, target=target,
+                     relax_pipeline=relax.get_default_pipeline(target))
+    dev = tvm.device("cuda", 0)
+    vm = relax.VirtualMachine(ex, dev)
+
+    gpu_data = tvm.nd.array(raw_data, dev)
+    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
+    gpu_out = vm["main"](gpu_data, *gpu_params)
+
+    pytorch_out = torch_module(torch_data).detach().numpy()
+    actual = gpu_out[0].numpy()
+    desired = pytorch_out
+    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5,
+                               atol=1e-5)
+
+
+def test_upsample_with_size():
+    """
+    The Upsample module can be used with the size argument or the scale
+    factor argument but not both. This tests the former.
+    """
+    batch_size = 1
+    channels = 3
+    height, width = 8, 8
+
+    torch_module = Upsample(
+        size=(64, 64),
+        mode='nearest',
+        recompute_scale_factor=None)
+
+    raw_data = np.random.rand(
+        batch_size, channels, height, width).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module)
+
+
+def test_upsample_with_scale_factor():
+    """
+    The Upsample module can be used with the size argument or the scale
+    factor argument but not both. This tests the latter.
+    """
+    batch_size = 2
+    channels = 3
+    height, width = 32, 32
+
+    torch_module = Upsample(size=None, scale_factor = 7, mode = 'nearest',
+                            align_corners=None, recompute_scale_factor=True)
+
+    raw_data = np.random.rand(
+        batch_size, channels, height, width).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
\ No newline at end of file
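
Note on the TODOs above: torch.export lowers nn.Upsample to the ".vec"
overloads of the aten upsample ops, and their schemas answer both questions.
A minimal sketch to check them, assuming PyTorch 2.x (._schema is a private
attribute and the exact strings may differ across versions):

    import torch

    # Prints roughly:
    # upsample_nearest2d.vec(Tensor input, SymInt[]? output_size,
    #                        float[]? scale_factors) -> Tensor
    print(torch.ops.aten.upsample_nearest2d.vec._schema)

    # Prints roughly:
    # upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size,
    #                         bool align_corners, float[]? scale_factors) -> Tensor
    print(torch.ops.aten.upsample_bilinear2d.vec._schema)

If those schemas hold, nearest2d takes no align_corners argument at all and
args[2] is scale_factors, while for bilinear2d align_corners really is args[2]
and scale_factors is args[3]. scale_factors carries one float per spatial
dimension, which is why the translator sees a [scale_factor, scale_factor]
list rather than a single number.
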
+ """ + batch_size = 2 + channels = 3 + height, width = 32, 32 + + torch_module = Upsample(size=None, scale_factor = 7, mode = 'nearest', + align_corners=None, recompute_scale_factor=True) + + raw_data = np.random.rand( + batch_size, channels, height, width).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module) + + +if __name__ == "__main__": + tvm.testing.main() \ No newline at end of file From c26b504167acb01189765448f06af515ba50b565 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sat, 8 Mar 2025 16:47:46 -0500 Subject: [PATCH 2/7] update TODOs for upsample --- .../frontend/torch/exported_program_translator.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a2c14280bb5f..294f773f8616 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -116,15 +116,10 @@ def _upsample_impl( def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) - # TODO do we need the condition on size like we have in _upsample_nearest2d? - - # TODO HL: I am doubtful that align_corners is args[2]. The pytorch - # arguments go size, scale_factor, mode, align_corner. See changes I - # made to _upsample_nearest2d. Need to test for _upsample_bilinear2d align_corners = ( node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True) ) - scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) # TODO non-sense to default to 1! Understand why "None" doesn't work! + scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) return self._upsample_impl(x, size=size, scale_factor=scale_factor, method="linear", align_corners=align_corners) @@ -137,9 +132,11 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: align_corners = node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) else: - # TODO figure out why pytorch passes a list [scale_factor,scale_factor] instead of just an int. Passing first element for now - scale_factor = node.args[2][0] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1) # TODO non-sense to default to 1! Understand why "None" doesn't work! - align_corners = node.args[3] if len(node.args) > 3 else node.kwargs.get("align_corners", None) # TODO pytorch defaults to None + # TODO figure out why pytorch export passes a list such as + # [scale_factor,scale_factor] instead of just an int for + # scale_factor. 
Using first element for now + scale_factor = node.args[2][0] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1) + align_corners = node.args[3] if len(node.args) > 3 else node.kwargs.get("align_corners", None) return self._upsample_impl(x, size=size, scale_factor=scale_factor, method="nearest_neighbor", From 94c80dd3d05869e43bb180c948feba4f5514df14 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 9 Mar 2025 17:43:18 -0400 Subject: [PATCH 3/7] Black formatter --- .../torch/base_fx_graph_translator.py | 6 +- .../torch/exported_program_translator.py | 46 ++++-- .../relax/test_from_exported_to_cuda.py | 42 +++--- .../test_frontend_from_exported_program.py | 136 ++++++++---------- 4 files changed, 114 insertions(+), 116 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 003ceebec6ff..d4885c8b4875 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -37,9 +37,9 @@ def __init__(self) -> None: self.env: Dict[fx.Node, relax.Expr] = {} self.params: Dict[torch.Tensor, relax.Expr] = {} self.block_builder: relax.BlockBuilder = None - self.convert_map: Dict[ - Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var] - ] = self.create_convert_map() + self.convert_map: Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]] = ( + self.create_convert_map() + ) ########## Utilities ########## diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 294f773f8616..1ca209aaf39d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -92,7 +92,12 @@ def _group_norm(self, node: fx.Node) -> relax.Var: ) def _upsample_impl( - self, x: relax.Expr, size, scale_factor, method: str, align_corners: bool, + self, + x: relax.Expr, + size, + scale_factor, + method: str, + align_corners: bool, ) -> relax.Var: coord_trans = "align_corners" if align_corners else "half_pixel" @@ -120,28 +125,39 @@ def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True) ) scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) - return self._upsample_impl(x, size=size, scale_factor=scale_factor, - method="linear", align_corners=align_corners) + return self._upsample_impl( + x, size=size, scale_factor=scale_factor, method="linear", align_corners=align_corners + ) def _upsample_nearest2d(self, node: fx.node) -> relax.Var: x = self.env[node.args[0]] size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) if size: - scale_factor = None # Can only define size or scale_factor, not both - align_corners = node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) - + scale_factor = None # Can only define size or scale_factor, not both + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) + ) + else: - # TODO figure out why pytorch export passes a list such as - # [scale_factor,scale_factor] instead of just an int for + # TODO figure out why pytorch export passes a list such as + # [scale_factor,scale_factor] instead of just an int for # scale_factor. 
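
Note: the list that the remaining TODO asks about can be reproduced in
isolation. A minimal sketch, assuming PyTorch 2.x export behavior (node names
and the exact printed graph vary by version):

    import torch
    from torch.export import export

    m = torch.nn.Upsample(scale_factor=7, mode="nearest")
    ep = export(m, (torch.randn(2, 3, 32, 32),))
    print(ep.graph)
    # The upsample node shows up roughly as:
    #   call_function[target=torch.ops.aten.upsample_nearest2d.vec](args = (x, None, [7.0, 7.0]))
    # i.e. output_size=None and scale_factors with one float per spatial
    # dimension, which is why the translator reads node.args[2][0].
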
From 94c80dd3d05869e43bb180c948feba4f5514df14 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 9 Mar 2025 17:43:18 -0400
Subject: [PATCH 3/7] Black formatter

---
 .../torch/base_fx_graph_translator.py         |   6 +-
 .../torch/exported_program_translator.py      |  46 ++++--
 .../relax/test_from_exported_to_cuda.py       |  42 +++---
 .../test_frontend_from_exported_program.py    | 136 ++++++++----------
 4 files changed, 114 insertions(+), 116 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 003ceebec6ff..d4885c8b4875 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -37,9 +37,9 @@ def __init__(self) -> None:
         self.env: Dict[fx.Node, relax.Expr] = {}
         self.params: Dict[torch.Tensor, relax.Expr] = {}
         self.block_builder: relax.BlockBuilder = None
-        self.convert_map: Dict[
-            Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]
-        ] = self.create_convert_map()
+        self.convert_map: Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]] = (
+            self.create_convert_map()
+        )
 
     ########## Utilities ##########
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 294f773f8616..1ca209aaf39d 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -92,7 +92,12 @@ def _group_norm(self, node: fx.Node) -> relax.Var:
         )
 
     def _upsample_impl(
-        self, x: relax.Expr, size, scale_factor, method: str, align_corners: bool,
+        self,
+        x: relax.Expr,
+        size,
+        scale_factor,
+        method: str,
+        align_corners: bool,
     ) -> relax.Var:
         coord_trans = "align_corners" if align_corners else "half_pixel"
 
@@ -120,28 +125,39 @@ def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var:
             node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True)
         )
         scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1)
-        return self._upsample_impl(x, size=size, scale_factor=scale_factor,
-                                   method="linear", align_corners=align_corners)
+        return self._upsample_impl(
+            x, size=size, scale_factor=scale_factor, method="linear", align_corners=align_corners
+        )
 
     def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
         x = self.env[node.args[0]]
         size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
         if size:
             scale_factor = None  # Can only define size or scale_factor, not both
-            align_corners = node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None)
-
+            align_corners = (
+                node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None)
+            )
+
         else:
             # TODO figure out why pytorch export passes a list such as
             # [scale_factor,scale_factor] instead of just an int for
             # scale_factor. Using first element for now
-            scale_factor = node.args[2][0] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1)
-            align_corners = node.args[3] if len(node.args) > 3 else node.kwargs.get("align_corners", None)
-
-        return self._upsample_impl(x, size=size, scale_factor=scale_factor,
-                                   method="nearest_neighbor",
-                                   align_corners=align_corners)
-
+            scale_factor = (
+                node.args[2][0] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1)
+            )
+            align_corners = (
+                node.args[3] if len(node.args) > 3 else node.kwargs.get("align_corners", None)
+            )
+
+        return self._upsample_impl(
+            x,
+            size=size,
+            scale_factor=scale_factor,
+            method="nearest_neighbor",
+            align_corners=align_corners,
+        )
+
     ########## Manipulation ##########
 
     def _select(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 611d5d302248..faaca48a1bdf 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -27,8 +27,8 @@
 def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
     """
-    This util ensures that a torch module can successfully be exported to TVM
-    using torch.export and that the resulting IR program gives the same result
+    This util ensures that a torch module can successfully be exported to TVM
+    using torch.export and that the resulting IR program gives the same result
     as PyTorch when run on CUDA.
     """
     torch_data = torch.from_numpy(raw_data)
     example_args = (torch_data,)
 
     with torch.no_grad():
         exported_program = export(torch_module, example_args)
-        mod_from_torch = from_exported_program(
-            exported_program, keep_params_as_input=True
-        )
+        mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)
 
     tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
     target = tvm.target.Target.from_device(tvm.cuda())
 
-    ex = relax.build(tvm_mod, target=target,
-                     relax_pipeline=relax.get_default_pipeline(target))
+    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax.get_default_pipeline(target))
     dev = tvm.device("cuda", 0)
     vm = relax.VirtualMachine(ex, dev)
 
@@ -55,47 +52,42 @@
     pytorch_out = torch_module(torch_data).detach().numpy()
     actual = gpu_out[0].numpy()
     desired = pytorch_out
-    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5,
-                               atol=1e-5)
+    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
 def test_upsample_with_size():
-    """
-    The Upsample module can be used with the size argument or the scale
+    """
+    The Upsample module can be used with the size argument or the scale
     factor argument but not both. This tests the former.
     """
     batch_size = 1
     channels = 3
     height, width = 8, 8
 
-    torch_module = Upsample(
-        size=(64, 64),
-        mode='nearest',
-        recompute_scale_factor=None)
+    torch_module = Upsample(size=(64, 64), mode="nearest", recompute_scale_factor=None)
 
-    raw_data = np.random.rand(
-        batch_size, channels, height, width).astype("float32")
+    raw_data = np.random.rand(batch_size, channels, height, width).astype("float32")
 
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module)
 
 
 def test_upsample_with_scale_factor():
     """
-    The Upsample module can be used with the size argument or the scale
-    factor argument but not both. This tests the latter.
+ The Upsample module can be used with the size arugment or the scale + factor argument but not both. This tests the latter. """ - batch_size = 2 + batch_size = 2 channels = 3 height, width = 32, 32 - torch_module = Upsample(size=None, scale_factor = 7, mode = 'nearest', - align_corners=None, recompute_scale_factor=True) + torch_module = Upsample( + size=None, scale_factor=7, mode="nearest", align_corners=None, recompute_scale_factor=True + ) - raw_data = np.random.rand( - batch_size, channels, height, width).astype("float32") + raw_data = np.random.rand(batch_size, channels, height, width).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module) if __name__ == "__main__": - tvm.testing.main() \ No newline at end of file + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8ca335c2fe7a..77aac527bc06 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -82,7 +82,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(input_1) @@ -112,7 +112,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1) @@ -135,7 +135,7 @@ def forward(self, input): class expected_clamp: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -163,7 +163,7 @@ def forward(self, input): class expected_dropout: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -191,7 +191,7 @@ def forward(self, input): class expected_gelu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -220,7 +220,7 @@ def forward(self, input): class expected_hardsigmoid: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) @@ -252,7 +252,7 @@ def forward(self, input): class expected1: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) @@ -294,7 +294,7 @@ def forward(self, input): class expected_relu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # 
block 0 with R.dataflow(): @@ -323,7 +323,7 @@ def forward(self, input): class expected_sigmoid: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -352,7 +352,7 @@ def forward(self, input): class expected_silu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -388,7 +388,7 @@ def forward(self, input): class expected1: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( @@ -425,7 +425,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -456,7 +456,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -487,7 +487,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -512,7 +512,7 @@ def forward(self, input): class expected_tril: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32") + input_1: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -531,7 +531,7 @@ def forward(self, input): class expected_triu: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32") + input_1: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -795,7 +795,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -883,7 +883,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -1580,7 +1580,7 @@ def forward(self, x, y): class Expected1: @R.function def main( - inp_0: R.Tensor((4, 4), dtype="float32") + inp_0: R.Tensor((4, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") @@ -1827,7 +1827,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -1856,7 +1856,7 @@ def forward(self, input): class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), 
dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -1885,7 +1885,7 @@ def forward(self, input): class expected3: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2007,9 +2007,7 @@ def forward(self, data): @tvm.script.ir_module class expected1: @R.function - def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2051,9 +2049,7 @@ def forward(self, data): @tvm.script.ir_module class expected2: @R.function - def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2102,7 +2098,7 @@ def forward(self, input): class expected_bilinear: @R.function def main( - input: R.Tensor((1, 3, 112, 112), dtype="float32") + input: R.Tensor((1, 3, 112, 112), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): # block 0 with R.dataflow(): @@ -2131,7 +2127,7 @@ def forward(self, input): class expected_nearest: @R.function def main( - input: R.Tensor((1, 3, 112, 112), dtype="float32") + input: R.Tensor((1, 3, 112, 112), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): # block 0 with R.dataflow(): @@ -2170,7 +2166,7 @@ def forward(self, input: torch.Tensor): class Expected1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((256,), dtype="float32")): with R.dataflow(): lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) @@ -2182,7 +2178,7 @@ def main( class Expected2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) @@ -2204,7 +2200,7 @@ def forward(self, x): class expected1: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -2238,7 +2234,7 @@ def forward(self, input): class expected_argmax1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((256,), dtype="int64")): with R.dataflow(): lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) @@ -2250,7 +2246,7 @@ def main( class expected_argmax2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")): with R.dataflow(): lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) @@ -2279,7 +2275,7 @@ def forward(self, input): class expected_argmin1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((), 
dtype="int64")): with R.dataflow(): lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) @@ -2291,7 +2287,7 @@ def main( class expected_argmin2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")): with R.dataflow(): lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) @@ -2362,7 +2358,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 2, 3, 4), dtype="float32") + input_1: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")): # block 0 with R.dataflow(): @@ -2388,7 +2384,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -2419,7 +2415,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")): # block 0 with R.dataflow(): @@ -2445,7 +2441,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): # block 0 with R.dataflow(): @@ -2483,7 +2479,7 @@ def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="fl class expected2: @R.function def main( - x: R.Tensor((1, 3), dtype="float32") + x: R.Tensor((1, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2511,7 +2507,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): # block 0 with R.dataflow(): @@ -2533,7 +2529,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 3, 10, 10), dtype="float32") + x: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")): # block 0 with R.dataflow(): @@ -2574,7 +2570,7 @@ def forward(self, x): class expected2: @R.function def main( - x: R.Tensor((8, 16), dtype="float32") + x: R.Tensor((8, 16), dtype="float32"), ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( @@ -2619,9 +2615,7 @@ def forward(self, input): @tvm.script.ir_module class Expected: @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), @@ -2651,9 +2645,7 @@ def forward(self, data): @tvm.script.ir_module class expected1: @R.function - def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2695,9 +2687,7 @@ def forward(self, data): @tvm.script.ir_module class expected2: @R.function - def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def 
main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2749,7 +2739,7 @@ def forward(self, input): class Expected1: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32"), ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1]) @@ -2765,7 +2755,7 @@ def forward(self, input): class Expected2: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32"), ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): with R.dataflow(): lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) @@ -2796,7 +2786,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 3), dtype="float32") + x: R.Tensor((1, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2809,7 +2799,7 @@ def main( class expected2: @R.function def main( - x: R.Tensor((1, 3), dtype="float32") + x: R.Tensor((1, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2833,7 +2823,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): # block 0 with R.dataflow(): @@ -2855,7 +2845,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -2872,7 +2862,7 @@ def forward(self, input): class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")): # block 0 with R.dataflow(): @@ -2896,7 +2886,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): # block 0 with R.dataflow(): @@ -2918,7 +2908,7 @@ def forward(self, input): class Expected: @R.function def main( - input: R.Tensor((10, 10), dtype="float32") + input: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((20,), dtype="int32")): with R.dataflow(): lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32") @@ -2939,7 +2929,7 @@ def forward(self, input): class Expected: @R.function def main( - input: R.Tensor((10, 10), dtype="float32") + input: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,) @@ -2959,7 +2949,7 @@ def forward(self, input): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32") + inp_0: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.zeros( @@ -2982,7 +2972,7 @@ def forward(self, input: torch.Tensor): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32") + inp_0: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with 
R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full( @@ -3005,7 +2995,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3), dtype="float32") + x: R.Tensor((1, 2, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): # block 0 with R.dataflow(): @@ -3034,7 +3024,7 @@ def forward(self, x): class expected_float: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -3052,7 +3042,7 @@ def forward(self, x): class expected_half: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): # block 0 with R.dataflow(): @@ -3070,7 +3060,7 @@ def forward(self, x): class expected_type: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -3086,7 +3076,7 @@ def forward(self, input): class expected_to1: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): with R.dataflow(): lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16") @@ -3102,7 +3092,7 @@ def forward(self, input): class expected_to2: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32") @@ -3187,7 +3177,7 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tensor((256, 256), dtype="float32"): with R.dataflow(): gv: R.Tensor((256, 256), dtype="float32") = inp_0 From af7e52808686ca97fdccd04599b21d3f2626b712 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 03:44:08 -0400 Subject: [PATCH 4/7] updated cuda test syntax and ran Black Python formatter with version 22 --- .../relax/test_from_exported_to_cuda.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index faaca48a1bdf..cc317f05bde5 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -25,12 +25,13 @@ from torch.nn import Softmax, Upsample -def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module): +def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): """ This util ensures that a torch module can successfully be exported to TVM using torch.export and that the resuling IR program gives the same result as PyTorch when ran on CUDA. 
""" + raw_data_for_tvm = raw_data.copy() # In case the data is modified torch_data = torch.from_numpy(raw_data) example_args = (torch_data,) @@ -39,13 +40,14 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module): mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) - target = tvm.target.Target.from_device(tvm.cuda()) - ex = relax.build(tvm_mod, target=target, relax_pipeline=relax.get_default_pipeline(target)) - dev = tvm.device("cuda", 0) + relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) + # TODO try pipeline below? + # releax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target) + ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) vm = relax.VirtualMachine(ex, dev) - gpu_data = tvm.nd.array(raw_data, dev) + gpu_data = tvm.nd.array(raw_data_for_tvm, dev) gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] gpu_out = vm["main"](gpu_data, *gpu_params) @@ -55,7 +57,8 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module): np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) -def test_upsample_with_size(): +@tvm.testing.parametrize_targets("cuda") +def test_upsample_with_size(target, dev): """ The Upsample module can be used with the size arugment or the scale factor argument but not both. This tests the former. @@ -68,10 +71,11 @@ def test_upsample_with_size(): raw_data = np.random.rand(batch_size, channels, height, width).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module) + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -def test_upsample_with_scale_factor(): +@tvm.testing.parametrize_targets("cuda") +def test_upsample_with_scale_factor(target, dev): """ The Upsample module can be used with the size arugment or the scale factor argument but not both. This tests the latter. 
@@ -86,8 +90,7 @@ def test_upsample_with_scale_factor(): raw_data = np.random.rand(batch_size, channels, height, width).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module) - + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) if __name__ == "__main__": tvm.testing.main() From 5c7b722aaa4009f096148009ad935eb8861df19c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 03:44:51 -0400 Subject: [PATCH 5/7] ran Black Python formatter with version 22 --- .../torch/base_fx_graph_translator.py | 6 +++--- .../relax/test_from_exported_to_cuda.py | 1 + .../test_frontend_from_exported_program.py | 20 ++++++++++++++----- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d4885c8b4875..003ceebec6ff 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -37,9 +37,9 @@ def __init__(self) -> None: self.env: Dict[fx.Node, relax.Expr] = {} self.params: Dict[torch.Tensor, relax.Expr] = {} self.block_builder: relax.BlockBuilder = None - self.convert_map: Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]] = ( - self.create_convert_map() - ) + self.convert_map: Dict[ + Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var] + ] = self.create_convert_map() ########## Utilities ########## diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index cc317f05bde5..246bdbebfafc 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -92,5 +92,6 @@ def test_upsample_with_scale_factor(target, dev): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 77aac527bc06..399739146359 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2007,7 +2007,9 @@ def forward(self, data): @tvm.script.ir_module class expected1: @R.function - def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2049,7 +2051,9 @@ def forward(self, data): @tvm.script.ir_module class expected2: @R.function - def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2615,7 +2619,9 @@ def forward(self, input): @tvm.script.ir_module class Expected: @R.function - def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), @@ -2645,7 +2651,9 @@ def forward(self, data): @tvm.script.ir_module class expected1: @R.function - def main(input_1: R.Tensor((3, 3, 
10, 10), dtype="float32")) -> R.Tuple( + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2687,7 +2695,9 @@ def forward(self, data): @tvm.script.ir_module class expected2: @R.function - def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), From 6825ce55a4ea66b62948023e03cad2d9e235bfd1 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 04:31:35 -0400 Subject: [PATCH 6/7] restore unmodified frontend test --- .../test_frontend_from_exported_program.py | 116 +++++++++--------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 399739146359..8ca335c2fe7a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -82,7 +82,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(input_1) @@ -112,7 +112,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1) @@ -135,7 +135,7 @@ def forward(self, input): class expected_clamp: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -163,7 +163,7 @@ def forward(self, input): class expected_dropout: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -191,7 +191,7 @@ def forward(self, input): class expected_gelu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -220,7 +220,7 @@ def forward(self, input): class expected_hardsigmoid: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) @@ -252,7 +252,7 @@ def forward(self, input): class expected1: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) @@ -294,7 +294,7 @@ def forward(self, input): class expected_relu: @R.function def main( - input_1: R.Tensor((1, 3, 
10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -323,7 +323,7 @@ def forward(self, input): class expected_sigmoid: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -352,7 +352,7 @@ def forward(self, input): class expected_silu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -388,7 +388,7 @@ def forward(self, input): class expected1: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( @@ -425,7 +425,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -456,7 +456,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -487,7 +487,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -512,7 +512,7 @@ def forward(self, input): class expected_tril: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32"), + input_1: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -531,7 +531,7 @@ def forward(self, input): class expected_triu: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32"), + input_1: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -795,7 +795,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -883,7 +883,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -1580,7 +1580,7 @@ def forward(self, x, y): class Expected1: @R.function def main( - inp_0: R.Tensor((4, 4), dtype="float32"), + inp_0: R.Tensor((4, 4), dtype="float32") ) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") @@ -1827,7 +1827,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with 
R.dataflow(): @@ -1856,7 +1856,7 @@ def forward(self, input): class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -1885,7 +1885,7 @@ def forward(self, input): class expected3: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2102,7 +2102,7 @@ def forward(self, input): class expected_bilinear: @R.function def main( - input: R.Tensor((1, 3, 112, 112), dtype="float32"), + input: R.Tensor((1, 3, 112, 112), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): # block 0 with R.dataflow(): @@ -2131,7 +2131,7 @@ def forward(self, input): class expected_nearest: @R.function def main( - input: R.Tensor((1, 3, 112, 112), dtype="float32"), + input: R.Tensor((1, 3, 112, 112), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): # block 0 with R.dataflow(): @@ -2170,7 +2170,7 @@ def forward(self, input: torch.Tensor): class Expected1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((256,), dtype="float32")): with R.dataflow(): lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) @@ -2182,7 +2182,7 @@ def main( class Expected2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) @@ -2204,7 +2204,7 @@ def forward(self, x): class expected1: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -2238,7 +2238,7 @@ def forward(self, input): class expected_argmax1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((256,), dtype="int64")): with R.dataflow(): lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) @@ -2250,7 +2250,7 @@ def main( class expected_argmax2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")): with R.dataflow(): lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) @@ -2279,7 +2279,7 @@ def forward(self, input): class expected_argmin1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((), dtype="int64")): with R.dataflow(): lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) @@ -2291,7 +2291,7 @@ def main( class expected_argmin2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")): with R.dataflow(): lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) @@ -2362,7 +2362,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 2, 3, 4), dtype="float32"), + input_1: R.Tensor((1, 2, 3, 4), 
dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")): # block 0 with R.dataflow(): @@ -2388,7 +2388,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -2419,7 +2419,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")): # block 0 with R.dataflow(): @@ -2445,7 +2445,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): # block 0 with R.dataflow(): @@ -2483,7 +2483,7 @@ def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="fl class expected2: @R.function def main( - x: R.Tensor((1, 3), dtype="float32"), + x: R.Tensor((1, 3), dtype="float32") ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2511,7 +2511,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): # block 0 with R.dataflow(): @@ -2533,7 +2533,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 3, 10, 10), dtype="float32"), + x: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")): # block 0 with R.dataflow(): @@ -2574,7 +2574,7 @@ def forward(self, x): class expected2: @R.function def main( - x: R.Tensor((8, 16), dtype="float32"), + x: R.Tensor((8, 16), dtype="float32") ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( @@ -2749,7 +2749,7 @@ def forward(self, input): class Expected1: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32"), + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1]) @@ -2765,7 +2765,7 @@ def forward(self, input): class Expected2: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32"), + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): with R.dataflow(): lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) @@ -2796,7 +2796,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 3), dtype="float32"), + x: R.Tensor((1, 3), dtype="float32") ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2809,7 +2809,7 @@ def main( class expected2: @R.function def main( - x: R.Tensor((1, 3), dtype="float32"), + x: R.Tensor((1, 3), dtype="float32") ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2833,7 +2833,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): # block 0 with R.dataflow(): @@ -2855,7 +2855,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 
10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -2872,7 +2872,7 @@ def forward(self, input): class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")): # block 0 with R.dataflow(): @@ -2896,7 +2896,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): # block 0 with R.dataflow(): @@ -2918,7 +2918,7 @@ def forward(self, input): class Expected: @R.function def main( - input: R.Tensor((10, 10), dtype="float32"), + input: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((20,), dtype="int32")): with R.dataflow(): lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32") @@ -2939,7 +2939,7 @@ def forward(self, input): class Expected: @R.function def main( - input: R.Tensor((10, 10), dtype="float32"), + input: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,) @@ -2959,7 +2959,7 @@ def forward(self, input): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32"), + inp_0: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.zeros( @@ -2982,7 +2982,7 @@ def forward(self, input: torch.Tensor): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32"), + inp_0: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full( @@ -3005,7 +3005,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3), dtype="float32"), + x: R.Tensor((1, 2, 3), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): # block 0 with R.dataflow(): @@ -3034,7 +3034,7 @@ def forward(self, x): class expected_float: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -3052,7 +3052,7 @@ def forward(self, x): class expected_half: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): # block 0 with R.dataflow(): @@ -3070,7 +3070,7 @@ def forward(self, x): class expected_type: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -3086,7 +3086,7 @@ def forward(self, input): class expected_to1: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): with R.dataflow(): lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16") @@ -3102,7 +3102,7 @@ def forward(self, input): class expected_to2: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 2, 3, 4), 
dtype="float32") = R.astype(inp_0, dtype="float32") @@ -3187,7 +3187,7 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tensor((256, 256), dtype="float32"): with R.dataflow(): gv: R.Tensor((256, 256), dtype="float32") = inp_0 From 28b91d5652749d941f2d49fb679461ae4d001a3b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 12:14:15 -0400 Subject: [PATCH 7/7] ran Python Black formatter --- tests/python/relax/test_from_exported_to_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 5129ca096147..69daab36a581 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -90,6 +90,7 @@ def test_upsample_with_scale_factor(target, dev): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_linalg_vector_norm(target, dev): class VectorNorm0(torch.nn.Module):