Skip to content

Commit 4ef582a

Browse files
authored
[Relax][PyTorch] Add support for linspace op in fx graph (#17915)
* Update fx_translator.py * Update test_frontend_from_fx.py
1 parent 3f27aa8 commit 4ef582a

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,7 @@ def create_convert_map(
835835
"fill_": self._inplace_fill,
836836
"full": self._full,
837837
"index_select": self._index_select,
838+
"linspace": self._linspace,
838839
"masked_fill_": self._inplace_masked_fill,
839840
"masked_fill": self._masked_fill,
840841
"masked_scatter": self._masked_scatter,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5396,5 +5396,23 @@ def forward(self, input):
53965396
)
53975397

53985398

5399+
def test_linspace():
5400+
import numpy as np
5401+
5402+
class Linspace(Module):
5403+
def forward(self, input):
5404+
return torch.linspace(0, 1, steps=9)
5405+
5406+
graph_model = fx.symbolic_trace(Linspace())
5407+
mod = from_fx(graph_model, [([9, 9], "float32")])
5408+
assert len(mod["main"].body.blocks) == 1
5409+
assert len(mod["main"].body.blocks[0].bindings) == 1
5410+
assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
5411+
tvm.testing.assert_allclose(
5412+
mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
5413+
np.linspace(0, 1, num=9, dtype="float32"),
5414+
)
5415+
5416+
53995417
if __name__ == "__main__":
54005418
tvm.testing.main()

0 commit comments

Comments
 (0)