Skip to content

Commit 1437d5c

Browse files
[Relax][Pytorch] Add support for ones_like, zero_, zeros, type_as, item ops (#17868)
* Add support for ones_like,zero_,zeros,type_as,item * Fix lint issues * Fix lint issues * Removed unused import
1 parent 299ef81 commit 1437d5c

File tree

5 files changed

+237
-0
lines changed

5 files changed

+237
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,6 +1501,12 @@ def _to(self, node: fx.Node) -> relax.Var:
15011501
return self.block_builder.emit(relax.op.astype(x, dtype))
15021502
return x
15031503

1504+
def _type_as(self, node: fx.Node) -> relax.Var:
1505+
x = self.env[node.args[0]]
1506+
other = self.env[node.args[1]]
1507+
dtype = other.struct_info.dtype
1508+
return self.block_builder.emit(relax.op.astype(x, dtype))
1509+
15041510
########## Others ##########
15051511

15061512
def _getitem(self, node: fx.Node) -> relax.Var:
@@ -1584,6 +1590,16 @@ def _getitem(self, node: fx.Node) -> relax.Var:
15841590
else:
15851591
assert False
15861592

1593+
def _item(self, node: fx.Node) -> relax.Var:
1594+
x = self.env[node.args[0]]
1595+
return self.block_builder.emit(relax.op.take(x, relax.const(0, "int64"), axis=0))
1596+
1597+
def _zeros_inplace(self, node: fx.Node) -> relax.Var:
1598+
x = self.env[node.args[0]]
1599+
output = self.block_builder.emit(relax.op.zeros_like(x))
1600+
self.env[node.args[0]] = output
1601+
return output
1602+
15871603
@abc.abstractmethod
15881604
def create_convert_map(
15891605
self,

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ def _one_hot(self, node: fx.Node) -> relax.Var:
253253

254254
return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis))
255255

256+
def _zeros(self, node: fx.Node) -> relax.Var:
257+
args = self.retrieve_args(node)
258+
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
259+
dtype = self._convert_data_type(
260+
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
261+
)
262+
return self.block_builder.emit(relax.op.zeros(size, dtype))
263+
256264
########## Others ##########
257265

258266
def create_convert_map(
@@ -470,11 +478,18 @@ def create_convert_map(
470478
"new_ones.default": self._new_ones,
471479
"one_hot.default": self._one_hot,
472480
"ones.default": self._ones,
481+
"ones_like.default": lambda node: self.block_builder.emit(
482+
relax.op.ones_like(self.env[node.args[0]])
483+
),
484+
"zero_.default": self._zeros_inplace,
485+
"zeros.default": self._zeros,
473486
# datatype
474487
"to.dtype": self._to,
475488
"to.dtype_layout": self._to,
489+
"type_as.default": self._type_as,
476490
# other
477491
"getitem": self._getitem,
492+
"item.default": self._item,
478493
}
479494

480495
def create_input_vars(

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,11 @@ def create_convert_map(
836836
"new_ones": self._new_ones,
837837
"ones": self._ones,
838838
"one_hot": self._one_hot,
839+
"ones_like": lambda node: self.block_builder.emit(
840+
relax.op.ones_like(self.env[node.args[0]])
841+
),
839842
"tensor": self._tensor,
843+
"zero_": self._zeros_inplace,
840844
"copy_": self._inplace_copy,
841845
# datatype
842846
"astype": self._type,
@@ -845,10 +849,12 @@ def create_convert_map(
845849
"is_floating_point": self._is_floating_point,
846850
"to": self._to,
847851
"type": self._type,
852+
"type_as": self._type_as,
848853
# other
849854
"getattr": self._getattr,
850855
"getitem": self._getitem,
851856
"sym_size.int": self._sym_size_int,
857+
"item": self._item,
852858
}
853859

854860
def update_convert_map(self, custom_convert_map: dict):

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3948,6 +3948,98 @@ def main(
39483948
verify_model(OneHot(), example_args, {}, Expected)
39493949

39503950

3951+
def test_ones_like():
3952+
class OnesLike(Module):
3953+
def forward(self, input):
3954+
return torch.ones_like(input)
3955+
3956+
@tvm.script.ir_module
3957+
class Expected:
3958+
@R.function
3959+
def main(
3960+
input: R.Tensor((128, 128), dtype="float32")
3961+
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
3962+
with R.dataflow():
3963+
lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, dtype="void")
3964+
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
3965+
R.output(gv)
3966+
return gv
3967+
3968+
example_args = (torch.rand(128, 128, dtype=torch.float32),)
3969+
3970+
verify_model(OnesLike(), example_args, {}, Expected)
3971+
3972+
3973+
def test_zero_inplace():
3974+
class ZeroInplace(Module):
3975+
def forward(self, input):
3976+
return input.zero_()
3977+
3978+
@tvm.script.ir_module
3979+
class Expected:
3980+
@R.function
3981+
def main(
3982+
input: R.Tensor((128, 128), dtype="float32")
3983+
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
3984+
with R.dataflow():
3985+
lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void")
3986+
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
3987+
R.output(gv)
3988+
return gv
3989+
3990+
example_args = (torch.rand(128, 128, dtype=torch.float32),)
3991+
3992+
verify_model(ZeroInplace(), example_args, {}, Expected)
3993+
3994+
3995+
def test_zeros():
3996+
class Zeros(Module):
3997+
def forward(self, input):
3998+
return torch.zeros(5, 2)
3999+
4000+
@tvm.script.ir_module
4001+
class Expected:
4002+
@R.function
4003+
def main(
4004+
input: R.Tensor((128, 128), dtype="float32")
4005+
) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
4006+
with R.dataflow():
4007+
lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 2]), dtype="float32")
4008+
gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
4009+
R.output(gv)
4010+
return gv
4011+
4012+
example_args = (torch.rand(128, 128, dtype=torch.float32),)
4013+
4014+
verify_model(Zeros(), example_args, {}, Expected)
4015+
4016+
4017+
def test_type_as():
4018+
class TypeAs(Module):
4019+
def forward(self, input, other):
4020+
return input.type_as(other)
4021+
4022+
@tvm.script.ir_module
4023+
class Expected:
4024+
@R.function
4025+
def main(
4026+
input: R.Tensor((128, 128), dtype="float32"),
4027+
other: R.Tensor((128, 128), dtype="float16"),
4028+
) -> R.Tuple(R.Tensor((128, 128), dtype="float16")):
4029+
with R.dataflow():
4030+
lv: R.Tensor((128, 128), dtype="float16") = R.astype(input, dtype="float16")
4031+
gv: R.Tuple(R.Tensor((128, 128), dtype="float16")) = (lv,)
4032+
R.output(gv)
4033+
return gv
4034+
4035+
example_args = (
4036+
torch.rand(128, 128, dtype=torch.float32),
4037+
torch.rand(128, 128, dtype=torch.float16),
4038+
)
4039+
4040+
verify_model(TypeAs(), example_args, {}, Expected)
4041+
4042+
39514043
def test_select():
39524044
class Select(Module):
39534045
def forward(self, input):
@@ -4379,6 +4471,25 @@ def main(
43794471
verify_model(Narrow(), example_args, {}, Expected)
43804472

43814473

4474+
def test_item():
4475+
class Item(Module):
4476+
def forward(self, x):
4477+
return x.item()
4478+
4479+
@tvm.script.ir_module
4480+
class Expected:
4481+
@R.function
4482+
def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
4483+
with R.dataflow():
4484+
lv: R.Tensor((), dtype="float32") = R.take(input, R.const(0, "int64"), axis=0)
4485+
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
4486+
R.output(gv)
4487+
return gv
4488+
4489+
example_args = (torch.randn(1, dtype=torch.float32),)
4490+
verify_model(Item(), example_args, {}, Expected)
4491+
4492+
43824493
def test_norm():
43834494
class Norm(Module):
43844495
def __init__(self, p, dim=None, keepdim=False):

tests/python/relax/test_frontend_from_fx.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4506,6 +4506,95 @@ def main(
45064506
verify_model(EmptyLike(), [([5], "float32")], {}, Expected)
45074507

45084508

4509+
def test_ones_like():
4510+
class OnesLike(Module):
4511+
def forward(self, data):
4512+
return torch.ones_like(data)
4513+
4514+
@tvm.script.ir_module
4515+
class Expected:
4516+
@R.function
4517+
def main(
4518+
inp_0: R.Tensor((128, 128), dtype="float32")
4519+
) -> R.Tensor((128, 128), dtype="float32"):
4520+
with R.dataflow():
4521+
lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(inp_0, dtype="void")
4522+
gv: R.Tensor((128, 128), dtype="float32") = lv
4523+
R.output(gv)
4524+
return gv
4525+
4526+
verify_model(OnesLike(), [([128, 128], "float32")], {}, Expected)
4527+
4528+
4529+
def test_zero_inplace():
4530+
class ZeroInplace(Module):
4531+
def forward(self, data):
4532+
return data.zero_()
4533+
4534+
@tvm.script.ir_module
4535+
class Expected:
4536+
@R.function
4537+
def main(
4538+
inp_0: R.Tensor((128, 128), dtype="float32")
4539+
) -> R.Tensor((128, 128), dtype="float32"):
4540+
with R.dataflow():
4541+
lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void")
4542+
gv: R.Tensor((128, 128), dtype="float32") = lv
4543+
R.output(gv)
4544+
return gv
4545+
4546+
verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected)
4547+
4548+
4549+
def test_type_as():
4550+
class TypeAs(Module):
4551+
def forward(self, data, other):
4552+
return data.type_as(other)
4553+
4554+
@tvm.script.ir_module
4555+
class Expected:
4556+
@R.function
4557+
def main(
4558+
inp_0: R.Tensor((128, 128), dtype="float16"),
4559+
inp_1: R.Tensor((128, 128), dtype="float32"),
4560+
) -> R.Tensor((128, 128), dtype="float32"):
4561+
with R.dataflow():
4562+
lv: R.Tensor((128, 128), dtype="float32") = R.astype(inp_0, dtype="float32")
4563+
gv: R.Tensor((128, 128), dtype="float32") = lv
4564+
R.output(gv)
4565+
return gv
4566+
4567+
verify_model(TypeAs(), [([128, 128], "float16"), ([128, 128], "float32")], {}, Expected)
4568+
4569+
4570+
def test_item():
4571+
class Item(Module):
4572+
def forward(self, data):
4573+
return data.item()
4574+
4575+
@tvm.script.ir_module
4576+
class Expected:
4577+
@R.function
4578+
def main(inp_0: R.Tensor((1,), dtype="float32")) -> R.Tensor((), dtype="float32"):
4579+
with R.dataflow():
4580+
lv: R.Tensor((), dtype="float32") = R.take(inp_0, R.const(0, "int64"), axis=0)
4581+
gv: R.Tensor((), dtype="float32") = lv
4582+
R.output(gv)
4583+
return gv
4584+
4585+
verify_model(
4586+
Item(),
4587+
[
4588+
(
4589+
[1],
4590+
"float32",
4591+
)
4592+
],
4593+
{},
4594+
Expected,
4595+
)
4596+
4597+
45094598
def test_numel():
45104599
class Numel(Module):
45114600
def forward(self, data):

0 commit comments

Comments
 (0)