diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4ce899685a7e..003ceebec6ff 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -847,6 +847,40 @@ def _expand(self, node: fx.Node) -> relax.Var:
                 broadcast_shape.append(i)
         return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
 
+    def _flip(self, node: fx.Node) -> relax.Var:
+        """Translate ``torch.flip`` into one ``relax.op.flip`` per requested axis.
+
+        ``dims`` is taken from the second positional argument or the ``dims``
+        keyword and may be a single int or a list/tuple of ints. An empty
+        sequence yields the input unchanged.
+
+        Raises
+        ------
+        TypeError
+            If ``dims`` is neither an int nor a list/tuple of ints.
+        """
+        x = self.env[node.args[0]]
+        dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
+        if isinstance(dims, int):
+            dims = [dims]
+        if not isinstance(dims, (list, tuple)) or not all(isinstance(d, int) for d in dims):
+            raise TypeError(f"flip expects an integer axis, but got {type(dims)}: {dims}")
+        # Each relax.op.flip call below reverses one axis, so chain one per entry.
+        for axis in dims:
+            x = self.block_builder.emit(relax.op.flip(x, axis))
+        return x
+
+    def _gather(self, node: fx.Node) -> relax.Var:
+        """Translate ``torch.gather`` into ``relax.op.gather_elements``.
+
+        ``dim`` and ``index`` may be passed positionally or by keyword; a missing
+        ``dim`` falls back to 0.
+        """
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        index = self.env[node.args[2] if len(node.args) > 2 else node.kwargs["index"]]
+        return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
+
     def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
@@ -921,6 +955,20 @@ def _stack(self, node: fx.Node) -> relax.Var:
                 s_shape.append(s)
         return self.block_builder.emit(relax.op.reshape(cat, s_shape))
 
+    def _take(self, node: fx.Node) -> relax.Var:
+        """Translate ``torch.take`` into ``relax.op.take``.
+
+        ``torch.take`` indexes its input as though it were 1-D, so an input whose
+        rank is not 1 is flattened first. Indices are cast to ``int32`` before
+        the lookup.
+        """
+        x = self.env[node.args[0]]
+        indices = self.env[node.args[1]]
+        if x.struct_info.ndim != 1:
+            x = self.block_builder.emit(relax.op.flatten(x))
+        indices = self.block_builder.emit(relax.op.astype(indices, "int32"))
+        return self.block_builder.emit(relax.op.take(x, indices))
+
     def _tile(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index af84f71bbf1e..ef98d3c02501 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -733,6 +733,8 @@ def create_convert_map(
             "cumsum": self._cumsum,
             "expand": self._expand,
             "flatten": self._flatten,
+            "flip": self._flip,
+            "gather": self._gather,
             "permute": self._permute,
             "repeat": self._repeat,
             "reshape": self._reshape,
@@ -741,6 +743,7 @@ def create_convert_map(
             "split": self._split,
             "squeeze": self._squeeze,
             "stack": self._stack,
+            "take": self._take,
             "tile": self._tile,
             "transpose": self._transpose,
             "unsqueeze": lambda node: self.block_builder.emit(
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e9fa7965315a..0b4b34e0c9bb 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3903,5 +3903,179 @@ def main(inp_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((), dtype="bool")
     verify_model(IsFloatingPoint(), [([2, 3], "float32")], {}, Expected)
 
 
+def test_gather():
+    """torch.gather maps to R.gather_elements for positive and negative axes."""
+
+    class Gather0(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, 0, indices)
+
+    class Gather1(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, 1, indices)
+
+    class Gather2(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, -1, indices)
+
+    class Gather3(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, -2, indices)
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=0)
+                gv: R.Tensor((2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=1)
+                gv: R.Tensor((2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-1)
+                gv: R.Tensor((2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-2)
+                gv: R.Tensor((2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Gather0(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected0)
+    verify_model(Gather1(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected1)
+    verify_model(Gather2(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected2)
+    verify_model(Gather3(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected3)
+
+
+def test_flip():
+    """torch.flip maps to R.flip, chained once per axis when several are given."""
+
+    class Flip0(Module):
+        def forward(self, data):
+            return torch.flip(data, [0])
+
+    class Flip1(Module):
+        def forward(self, data):
+            return torch.flip(data, [1])
+
+    class Flip2(Module):
+        def forward(self, data):
+            return torch.flip(data, [0, 1])
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0)
+                gv: R.Tensor((2, 2), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1)
+                gv: R.Tensor((2, 2), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0)
+                lv1: R.Tensor((2, 2), dtype="float32") = R.flip(lv, axis=1)
+                gv: R.Tensor((2, 2), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(Flip0(), [([2, 2], "float32")], {}, Expected0)
+    verify_model(Flip1(), [([2, 2], "float32")], {}, Expected1)
+    verify_model(Flip2(), [([2, 2], "float32")], {}, Expected2)
+
+
+def test_take():
+    """torch.take maps to R.take, flattening the input first when it is not 1-D."""
+
+    class Take(Module):
+        def forward(self, data, indices):
+            return torch.take(data, indices)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+            inp_1: R.Tensor((3,), dtype="int32"),
+        ) -> R.Tensor((3,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, "int32")
+                lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv)
+                gv: R.Tensor((3,), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2D:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((3,), dtype="int32"),
+        ) -> R.Tensor((3,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((6,), dtype="float32") = R.flatten(inp_0)
+                lv1: R.Tensor((3,), dtype="int32") = R.astype(inp_1, "int32")
+                lv2: R.Tensor((3,), dtype="float32") = R.take(lv, lv1)
+                gv: R.Tensor((3,), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+
+    verify_model(Take(), [([5], "float32"), ([3], "int32")], {}, Expected)
+    verify_model(Take(), [([2, 3], "float32"), ([3], "int32")], {}, Expected2D)
+
+
 if __name__ == "__main__":
     tvm.testing.main()