Skip to content

Commit 2a9709c

Browse files
authored
[Unity][Frontend] FX exp and strided_slice fix (#14338)
* Add the support of `exp` for the FX translator. * Previously the way FX translator dealt with `None` in torch tensor slice (e.g., `x[:, None, None]`) is not right. This PR fixes this issue. Specifically, the `None` here means dim expansion, and the previous impl mistakenly increases the dim counter when seeing `None`, which will lead to dim counter out-of-range issue in the end.
1 parent 57b42a8 commit 2a9709c

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ def _call_binary_op(self, op, lhs, rhs):
136136
def _cos(self, node: fx.node.Node) -> relax.Var:
137137
return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))
138138

139+
def _exp(self, node: fx.node.Node) -> relax.Var:
140+
return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))
141+
139142
def _sin(self, node: fx.node.Node) -> relax.Var:
140143
return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
141144

@@ -858,8 +861,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var:
858861
axes.append(i)
859862
i = i + 1
860863
elif index is None:
861-
expand_dim.append(i)
862-
i = i + 1
864+
expand_dim.append(len(axes) + len(expand_dim))
863865
else:
864866
raise ValueError("Unsupported index type: " + str(type(index)))
865867
while i < len(shape):
@@ -903,6 +905,7 @@ def create_convert_map(self):
903905
nn.modules.sparse.Embedding: self._embedding,
904906
# call_function and call_method
905907
"cos": self._cos,
908+
"exp": self._exp,
906909
"sin": self._sin,
907910
"add": self._add,
908911
"floordiv": self._floordiv,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tvm
2020
from tvm import relax
2121
import tvm.testing
22-
from tvm.script.parser import relax as R, tir as T
22+
from tvm.script.parser import ir as I, relax as R, tir as T
2323

2424

2525
def verify_model(torch_model, input_info, binding, expected):
@@ -1372,8 +1372,6 @@ def test_getitem():
13721372
torch.set_grad_enabled(False)
13731373
torch.random.manual_seed(0)
13741374

1375-
input_info = [([1, 3, 10, 10], "float32")]
1376-
13771375
class Slice1(Module):
13781376
def forward(self, x):
13791377
return x[0, 1::2, :, :3]
@@ -1398,7 +1396,29 @@ def main(
13981396
R.output(gv)
13991397
return gv
14001398

1401-
verify_model(Slice1(), input_info, {}, expected1)
1399+
class Slice2(Module):
1400+
def forward(self, x):
1401+
return x[:, None, None, :, None]
1402+
1403+
@I.ir_module
1404+
class expected2:
1405+
@R.function
1406+
def main(
1407+
inp_0: R.Tensor((8, 16), dtype="float32")
1408+
) -> R.Tensor((8, 1, 1, 16, 1), dtype="float32"):
1409+
with R.dataflow():
1410+
lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice(
1411+
inp_0, axes=[0, 1], begin=[0, 0], end=[8, 16], strides=[1, 1]
1412+
)
1413+
lv1: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.reshape(
1414+
lv, R.shape([8, 1, 1, 16, 1])
1415+
)
1416+
gv: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = lv1
1417+
R.output(gv)
1418+
return gv
1419+
1420+
verify_model(Slice1(), [([1, 3, 10, 10], "float32")], {}, expected1)
1421+
verify_model(Slice2(), [([8, 16], "float32")], {}, expected2)
14021422

14031423

14041424
@tvm.testing.requires_gpu
@@ -1451,6 +1471,26 @@ def main(
14511471

14521472
verify_model(Cos(), input_info, {}, expected2)
14531473

1474+
# exp
1475+
class Exp(Module):
1476+
def forward(self, input):
1477+
return torch.exp(input)
1478+
1479+
@tvm.script.ir_module
1480+
class expected_exp:
1481+
@R.function
1482+
def main(
1483+
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
1484+
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
1485+
# block 0
1486+
with R.dataflow():
1487+
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
1488+
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
1489+
R.output(gv)
1490+
return gv
1491+
1492+
verify_model(Exp(), input_info, {}, expected_exp)
1493+
14541494
# sqrt
14551495
class Sqrt(Module):
14561496
def forward(self, input):

0 commit comments

Comments
 (0)