Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,29 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
self.env[node.args[0]] = output
return output

def _linspace(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
start = args[0]
stop = args[1]
step = args[2]

if step != 1:
step = (stop - start) / (step - 1)
stop = stop + (step / 2)
else:
stop = start + step

if len(args) <= 3 or args[3] is None:
import torch

dtype = self._convert_data_type(str(torch.get_default_dtype()))
else:
dtype = self._convert_data_type(args[3])

return self.block_builder.emit(
relax.op.arange(start=start, end=stop, step=step, dtype=dtype)
)

def _masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ def create_convert_map(
"full_like.default": self._full_like,
"index_select.default": self._index_select,
"lift_fresh_copy.default": self._to_copy,
"linspace.default": self._linspace,
"masked_fill.Scalar": self._masked_fill,
"new_ones.default": self._new_ones,
"one_hot.default": self._one_hot,
Expand Down
21 changes: 21 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -4642,5 +4642,26 @@ def main(
verify_model(Eye2(), example_args2, {}, Expected2)


def test_linspace():
class Linspace(Module):
def forward(self, input):
return torch.linspace(0, 1, steps=9, dtype=torch.float32)

@tvm.script.ir_module
class Expected:
@R.function
def main(
input: R.Tensor((9, 9), dtype="float32")
) -> R.Tuple(R.Tensor((9,), dtype="float32")):
with R.dataflow():
lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32")
gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(9, 9, dtype=torch.float32),)
verify_model(Linspace(), example_args, {}, Expected)


if __name__ == "__main__":
tvm.testing.main()
Loading