diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 003ceebec6ff..a9f54d91e3ce 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1018,6 +1018,10 @@ def _empty(self, node: fx.Node) -> relax.Var: dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) return self.block_builder.emit(relax.op.zeros(node.args[0], dtype)) + def _empty_like(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.zeros_like(x)) + def _fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index ef98d3c02501..29d959818f21 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -409,6 +409,11 @@ def _flatten_module(self, node: fx.Node) -> relax.Var: end_dim = module.end_dim return self._flatten_impl(x, start_dim, end_dim) + def _numel(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + shape = self.shape_of(x) + return relax.const(reduce(lambda x, y: x * y, [s.value for s in shape]), "int32") + def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -511,6 +516,18 @@ def _ones(self, node: fx.Node) -> relax.Var: ) ) + def _one_hot(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes") + if num_classes is None: + raise ValueError("num_classes not found in node.args or node.kwargs") + on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1) + off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0) + axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1) + on_value = relax.PrimValue(on_value) + off_value = relax.PrimValue(off_value) + return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis)) + def _tensor(self, node: fx.Node) -> relax.Var: dtype = node.kwargs.get("dtype", None) if isinstance(node.args[0], float): @@ -735,6 +752,7 @@ def create_convert_map( "flatten": self._flatten, "flip": self._flip, "gather": self._gather, + "numel": self._numel, "permute": self._permute, "repeat": self._repeat, "reshape": self._reshape, @@ -753,6 +771,7 @@ def create_convert_map( # tensor creation "arange": self._arange, "empty": self._empty, + "empty_like": self._empty_like, "fill_": self._inplace_fill, "full": self._full, "index_select": self._index_select, @@ -761,6 +780,7 @@ def create_convert_map( "masked_scatter": self._masked_scatter, "new_ones": self._new_ones, "ones": self._ones, + "one_hot": self._one_hot, "tensor": self._tensor, # datatype "astype": self._type, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 0b4b34e0c9bb..020fc8f5b3c2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4037,5 +4037,67 @@ def main( verify_model(Take(), [([5], "float32"), ([3], "int32")], {}, Expected) +def test_one_hot(): + class OneHot(Module): + def forward(self, indices): + return torch.nn.functional.one_hot(indices, num_classes=10) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="int32"), + ) -> R.Tensor((5, 10), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((5, 10), dtype="int64") = R.one_hot( + inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1 + ) + gv: R.Tensor((5, 10), dtype="int64") = lv + R.output(gv) + + return gv + + verify_model(OneHot(), [([5], "int32")], {}, Expected) + + +def test_empty_like(): + class EmptyLike(Module): + def forward(self, data): + return torch.empty_like(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="float32"), + ) -> R.Tensor((5,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0) + gv: R.Tensor((5,), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(EmptyLike(), [([5], "float32")], {}, Expected) + + +def test_numel(): + class Numel(Module): + def forward(self, data): + return torch.numel(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + gv: R.Tensor((), dtype="int32") = R.const(15, "int32") + R.output(gv) + return gv + + verify_model(Numel(), [([5, 3], "float32")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main()