Skip to content

Commit ed477b0

Browse files
Add op support for zeros_like and fill_ (#17896)
* add op support for zeros_like and fill_ * fixing whitespace issues * unity issue * solved datatype issue * unity issue * lint error
1 parent 47b95ca commit ed477b0

File tree

5 files changed

+82
-9
lines changed

5 files changed

+82
-9
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,6 +1457,15 @@ def _fill(self, node: fx.Node) -> relax.Var:
14571457
value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
14581458
return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
14591459

1460+
def _inplace_fill(self, node: fx.Node) -> relax.Var:
1461+
args = self.retrieve_args(node)
1462+
x = args[0]
1463+
dtype = x.struct_info.dtype
1464+
value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
1465+
filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
1466+
self.env[node.args[0]] = filled
1467+
return filled
1468+
14601469
def _full(self, node: fx.Node) -> relax.Var:
14611470
import torch
14621471

@@ -1670,6 +1679,10 @@ def _zeros_inplace(self, node: fx.Node) -> relax.Var:
16701679
self.env[node.args[0]] = output
16711680
return output
16721681

1682+
def _zeros_like(self, node: fx.node) -> relax.Var:
1683+
x = self.env[node.args[0]]
1684+
return self.block_builder.emit(relax.op.zeros_like(x))
1685+
16731686
@abc.abstractmethod
16741687
def create_convert_map(
16751688
self,

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ def create_convert_map(
474474
"eye.default": self._eye,
475475
"eye.m": self._eye,
476476
"fill.Scalar": self._fill,
477+
"fill_.Scalar": self._inplace_fill,
477478
"full.default": self._full,
478479
"full_like.default": self._full_like,
479480
"index_select.default": self._index_select,
@@ -488,6 +489,7 @@ def create_convert_map(
488489
),
489490
"zero_.default": self._zeros_inplace,
490491
"zeros.default": self._zeros,
492+
"zeros_like.default": self._zeros_like,
491493
# datatype
492494
"to.dtype": self._to,
493495
"to.dtype_layout": self._to,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -515,15 +515,6 @@ def _size(self, node: fx.Node) -> relax.Expr:
515515

516516
########## Creation ##########
517517

518-
def _inplace_fill(self, node: fx.Node) -> relax.Var:
519-
args = self.retrieve_args(node)
520-
x = args[0]
521-
dtype = x.struct_info.dtype
522-
value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
523-
filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
524-
self.env[node.args[0]] = filled
525-
return filled
526-
527518
def _inplace_copy(self, node: fx.Node) -> relax.Var:
528519
src = self.env[node.args[1]]
529520
self.env[node.args[0]] = src
@@ -830,6 +821,7 @@ def create_convert_map(
830821
"clone": lambda node: self.env[node.args[0]],
831822
"empty": self._empty,
832823
"empty_like": self._empty_like,
824+
"fill": self._fill,
833825
"fill_": self._inplace_fill,
834826
"full": self._full,
835827
"index_select": self._index_select,
@@ -844,6 +836,7 @@ def create_convert_map(
844836
),
845837
"tensor": self._tensor,
846838
"zero_": self._zeros_inplace,
839+
"zeros_like": self._zeros_like,
847840
"copy_": self._inplace_copy,
848841
# datatype
849842
"astype": self._type,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3679,6 +3679,30 @@ def main(
36793679
verify_model(Fill(), example_args, {}, Expected)
36803680

36813681

3682+
def test_fill_inplace():
3683+
class FillInplace(Module):
3684+
def forward(self, input: torch.Tensor):
3685+
input.fill_(42.0)
3686+
return input
3687+
3688+
@tvm.script.ir_module
3689+
class Expected:
3690+
@R.function
3691+
def main(
3692+
x: R.Tensor((2, 3), dtype="float32")
3693+
) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
3694+
with R.dataflow():
3695+
lv: R.Tensor((2, 3), dtype="float32") = R.full(
3696+
R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32"
3697+
)
3698+
gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
3699+
R.output(gv)
3700+
return gv
3701+
3702+
example_args = (torch.randn(2, 3, dtype=torch.float32),)
3703+
verify_model(FillInplace(), example_args, {}, Expected)
3704+
3705+
36823706
def test_masked_fill():
36833707
class Masked_Fill(Module):
36843708
def forward(self, input: torch.Tensor, mask: torch.Tensor):
@@ -4046,6 +4070,27 @@ def main(
40464070
verify_model(Zeros(), example_args, {}, Expected)
40474071

40484072

4073+
def test_zeros_like():
4074+
class ZerosLike(Module):
4075+
def forward(self, input):
4076+
return torch.zeros_like(input)
4077+
4078+
@tvm.script.ir_module
4079+
class Expected:
4080+
@R.function
4081+
def main(
4082+
input: R.Tensor((128, 128), dtype="float32")
4083+
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
4084+
with R.dataflow():
4085+
lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void")
4086+
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
4087+
R.output(gv)
4088+
return gv
4089+
4090+
example_args = (torch.rand(128, 128, dtype=torch.float32),)
4091+
verify_model(ZerosLike(), example_args, {}, Expected)
4092+
4093+
40494094
def test_type_as():
40504095
class TypeAs(Module):
40514096
def forward(self, input, other):

tests/python/relax/test_frontend_from_fx.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4747,6 +4747,26 @@ def main(
47474747
verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected)
47484748

47494749

4750+
def test_zeros_like():
4751+
class ZerosLike(Module):
4752+
def forward(self, data):
4753+
return torch.zeros_like(data)
4754+
4755+
@tvm.script.ir_module
4756+
class Expected:
4757+
@R.function
4758+
def main(
4759+
inp_0: R.Tensor((128, 128), dtype="float32")
4760+
) -> R.Tensor((128, 128), dtype="float32"):
4761+
with R.dataflow():
4762+
lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void")
4763+
gv: R.Tensor((128, 128), dtype="float32") = lv
4764+
R.output(gv)
4765+
return gv
4766+
4767+
verify_model(ZerosLike(), [([128, 128], "float32")], {}, Expected)
4768+
4769+
47504770
def test_type_as():
47514771
class TypeAs(Module):
47524772
def forward(self, data, other):

0 commit comments

Comments
 (0)