diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index a89726495e5a..33f6ffc3132e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1501,6 +1501,12 @@ def _to(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.astype(x, dtype)) return x + def _type_as(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + other = self.env[node.args[1]] + dtype = other.struct_info.dtype + return self.block_builder.emit(relax.op.astype(x, dtype)) + ########## Others ########## def _getitem(self, node: fx.Node) -> relax.Var: @@ -1584,6 +1590,16 @@ def _getitem(self, node: fx.Node) -> relax.Var: else: assert False + def _item(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.take(x, relax.const(0, "int64"), axis=0)) + + def _zeros_inplace(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output = self.block_builder.emit(relax.op.zeros_like(x)) + self.env[node.args[0]] = output + return output + @abc.abstractmethod def create_convert_map( self, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index cdf0c46bb5ef..0434712050ed 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -253,6 +253,14 @@ def _one_hot(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis)) + def _zeros(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + return self.block_builder.emit(relax.op.zeros(size, dtype)) + ########## Others ########## def create_convert_map( @@ -470,11 +478,18 @@ def create_convert_map( "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, "ones.default": self._ones, + "ones_like.default": lambda node: self.block_builder.emit( + relax.op.ones_like(self.env[node.args[0]]) + ), + "zero_.default": self._zeros_inplace, + "zeros.default": self._zeros, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, + "type_as.default": self._type_as, # other "getitem": self._getitem, + "item.default": self._item, } def create_input_vars( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index c3bf8f045410..55abf20fcc03 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -836,7 +836,11 @@ def create_convert_map( "new_ones": self._new_ones, "ones": self._ones, "one_hot": self._one_hot, + "ones_like": lambda node: self.block_builder.emit( + relax.op.ones_like(self.env[node.args[0]]) + ), "tensor": self._tensor, + "zero_": self._zeros_inplace, "copy_": self._inplace_copy, # datatype "astype": self._type, @@ -845,10 +849,12 @@ def create_convert_map( "is_floating_point": self._is_floating_point, "to": self._to, "type": self._type, + "type_as": self._type_as, # other "getattr": self._getattr, "getitem": self._getitem, "sym_size.int": self._sym_size_int, + "item": self._item, } def update_convert_map(self, custom_convert_map: dict): diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index c6ead5aaccfb..108617991b1f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3948,6 +3948,98 @@ def main( verify_model(OneHot(), example_args, {}, Expected) +def test_ones_like(): + class OnesLike(Module): + def forward(self, input): + return torch.ones_like(input) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, dtype="void") + gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.rand(128, 128, dtype=torch.float32),) + + verify_model(OnesLike(), example_args, {}, Expected) + + +def test_zero_inplace(): + class ZeroInplace(Module): + def forward(self, input): + return input.zero_() + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void") + gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.rand(128, 128, dtype=torch.float32),) + + verify_model(ZeroInplace(), example_args, {}, Expected) + + +def test_zeros(): + class Zeros(Module): + def forward(self, input): + return torch.zeros(5, 2) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 2]), dtype="float32") + gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.rand(128, 128, dtype=torch.float32),) + + verify_model(Zeros(), example_args, {}, Expected) + + +def test_type_as(): + class TypeAs(Module): + def forward(self, input, other): + return input.type_as(other) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32"), + other: R.Tensor((128, 128), dtype="float16"), + ) -> R.Tuple(R.Tensor((128, 128), dtype="float16")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float16") = R.astype(input, dtype="float16") + gv: R.Tuple(R.Tensor((128, 128), dtype="float16")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.rand(128, 128, dtype=torch.float32), + torch.rand(128, 128, dtype=torch.float16), + ) + + verify_model(TypeAs(), example_args, {}, Expected) + + def test_select(): class Select(Module): def forward(self, input): @@ -4379,6 +4471,25 @@ def main( verify_model(Narrow(), example_args, {}, Expected) +def test_item(): + class Item(Module): + def forward(self, x): + return x.item() + + @tvm.script.ir_module + class Expected: + @R.function + def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.take(input, R.const(0, "int64"), axis=0) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, dtype=torch.float32),) + verify_model(Item(), example_args, {}, Expected) + + def test_norm(): class Norm(Module): def __init__(self, p, dim=None, keepdim=False): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f21cde6df23c..cb69398e0a00 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4506,6 +4506,95 @@ def main( verify_model(EmptyLike(), [([5], "float32")], {}, Expected) +def test_ones_like(): + class OnesLike(Module): + def forward(self, data): + return torch.ones_like(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((128, 128), dtype="float32") + ) -> R.Tensor((128, 128), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(inp_0, dtype="void") + gv: R.Tensor((128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(OnesLike(), [([128, 128], "float32")], {}, Expected) + + +def test_zero_inplace(): + class ZeroInplace(Module): + def forward(self, data): + return data.zero_() + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((128, 128), dtype="float32") + ) -> R.Tensor((128, 128), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void") + gv: R.Tensor((128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected) + + +def test_type_as(): + class TypeAs(Module): + def forward(self, data, other): + return data.type_as(other) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((128, 128), dtype="float16"), + inp_1: R.Tensor((128, 128), dtype="float32"), + ) -> R.Tensor((128, 128), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.astype(inp_0, dtype="float32") + gv: R.Tensor((128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(TypeAs(), [([128, 128], "float16"), ([128, 128], "float32")], {}, Expected) + + +def test_item(): + class Item(Module): + def forward(self, data): + return data.item() + + @tvm.script.ir_module + class Expected: + @R.function + def main(inp_0: R.Tensor((1,), dtype="float32")) -> R.Tensor((), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.take(inp_0, R.const(0, "int64"), axis=0) + gv: R.Tensor((), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + Item(), + [ + ( + [1], + "float32", + ) + ], + {}, + Expected, + ) + + def test_numel(): class Numel(Module): def forward(self, data):