Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,25 @@ def _new_ones(self, node: fx.Node) -> relax.Var:
)
)

def _new_zeros(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
input_tensor = args[0]
size = (
args[1]
if isinstance(args[1], (list, tuple))
else (args[1],)
if len(args[1:]) == 1
else args[1:]
)
size = relax.ShapeExpr(size)
return self.block_builder.emit(
relax.op.full(
size,
relax.const(0, input_tensor.struct_info.dtype),
input_tensor.struct_info.dtype,
)
)

def _ones(self, node: fx.Node) -> relax.Var:
import torch

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def create_convert_map(
"linspace.default": self._linspace,
"masked_fill.Scalar": self._masked_fill,
"new_ones.default": self._new_ones,
"new_zeros.default": self._new_zeros,
"one_hot.default": self._one_hot,
"ones.default": self._ones,
"ones_like.default": lambda node: self.block_builder.emit(
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ def create_convert_map(
"masked_fill": self._masked_fill,
"masked_scatter": self._masked_scatter,
"new_ones": self._new_ones,
"new_zeros": self._new_zeros,
"ones": self._ones,
"one_hot": self._one_hot,
"ones_like": lambda node: self.block_builder.emit(
Expand Down
23 changes: 23 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -3751,6 +3751,29 @@ def main(
verify_model(NewOnes(), example_args, {}, expected1)


def test_new_zeros():
class NewZeros(torch.nn.Module):
def forward(self, x):
return x.new_zeros(1, 128, 128)

@tvm.script.ir_module
class expected1:
@R.function
def main(
x: R.Tensor((1, 128, 128), dtype="float32")
) -> R.Tuple(R.Tensor((1, 128, 128), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 128, 128), dtype="float32") = R.full(
R.shape([1, 128, 128]), R.const(0, "float32"), dtype="float32"
)
gv: R.Tuple(R.Tensor((1, 128, 128), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(1, 128, 128, dtype=torch.float32),)
verify_model(NewZeros(), example_args, {}, expected1)


def test_to_copy():
# float
class ToFloat(Module):
Expand Down
25 changes: 25 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3327,6 +3327,31 @@ def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="
verify_model(NewOnes(), input_info, {}, expected1)


def test_new_zeros():
input_info = [([1, 128, 128], "float32")]

class NewZeros(Module):
def forward(self, x):
return x.new_zeros(1, 128, 128)

@tvm.script.ir_module
class expected:
@R.function
def main(
x: R.Tensor((1, 128, 128), dtype="float32")
) -> R.Tensor((1, 128, 128), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 128, 128), dtype="float32") = R.full(
(1, 128, 128), R.const(0.0, "float32"), dtype="float32"
)
gv: R.Tensor((1, 128, 128), dtype="float32") = lv
R.output(gv)
return gv

verify_model(NewZeros(), input_info, {}, expected)


def test_expand():
input_info = [([1, 2, 3, 4], "float32")]

Expand Down