diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5ed0f18deb9e..ab3198f1d0a7 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -518,6 +518,14 @@ def _baddbmm(self, node: fx.node.Node) -> relax.Var:
         res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
         return res
 
+    def _einsum(self, node: fx.node.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        if isinstance(args[1], (torch.Size, tuple, list)):
+            return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0]))
+        return self.block_builder.emit(relax.op.einsum(args[1:], args[0]))
+
     ########## Manipulation ##########
 
     def _cat(self, node: fx.node.Node) -> relax.Var:
@@ -1478,6 +1486,7 @@ def create_convert_map(self):
             "max": self._max,
             "cross_entropy": self._cross_entropy,
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
+            "einsum": self._einsum,
         }
 
     def update_convert_map(self, custom_convert_map: dict):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index dd2719f8ce91..be7edc913cc2 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -650,6 +650,49 @@ def main(
     )
 
 
+def test_einsum():
+    class Einsum1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.einsum("ii", x)
+
+    class Einsum2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.einsum("i,j->ij", x, y)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(inp_0: R.Tensor((4, 4), dtype="float32")) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii")
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32")
+        ) -> R.Tensor((5, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 4), dtype="float32") = R.einsum(
+                    (inp_0, inp_1), subscripts="i,j->ij"
+                )
+                gv: R.Tensor((5, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Einsum1(), [([4, 4], "float32")], {}, Expected1)
+    verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)
+
+
 def test_relu():
     class ReLU0(Module):
         def __init__(self):