Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,14 @@ def _baddbmm(self, node: fx.node.Node) -> relax.Var:
res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
return res

def _einsum(self, node: fx.node.Node) -> relax.Var:
import torch # type: ignore

args = self.retrieve_args(node)
if isinstance(args[1], (torch.Size, tuple, list)):
return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0]))
return self.block_builder.emit(relax.op.einsum(args[1:], args[0]))

########## Manipulation ##########

def _cat(self, node: fx.node.Node) -> relax.Var:
Expand Down Expand Up @@ -1478,6 +1486,7 @@ def create_convert_map(self):
"max": self._max,
"cross_entropy": self._cross_entropy,
"scaled_dot_product_attention": self._scaled_dot_product_attention,
"einsum": self._einsum,
}

def update_convert_map(self, custom_convert_map: dict):
Expand Down
43 changes: 43 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,49 @@ def main(
)


def test_einsum():
class Einsum1(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.einsum("ii", x)

class Einsum2(Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.einsum("i,j->ij", x, y)

@tvm.script.ir_module
class Expected1:
@R.function
def main(inp_0: R.Tensor((4, 4), dtype="float32")) -> R.Tensor((), dtype="float32"):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii")
gv: R.Tensor((), dtype="float32") = lv
R.output(gv)
return gv

@tvm.script.ir_module
class Expected2:
@R.function
def main(
inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32")
) -> R.Tensor((5, 4), dtype="float32"):
with R.dataflow():
lv: R.Tensor((5, 4), dtype="float32") = R.einsum(
(inp_0, inp_1), subscripts="i,j->ij"
)
gv: R.Tensor((5, 4), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Einsum1(), [([4, 4], "float32")], {}, Expected1)
verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)


def test_relu():
class ReLU0(Module):
def __init__(self):
Expand Down