Skip to content

Commit cc8afdb

Browse files
authored
Add support for torch.nn.functional.max_pool2d (#17189)
* add a testcase for call_function * add maxpool2d to call_function
1 parent 9f0f301 commit cc8afdb

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,6 +1476,7 @@ def create_convert_map(self):
14761476
"getitem": self._getitem,
14771477
"contiguous": lambda node: self.env[node.args[0]],
14781478
"to": self._to,
1479+
"max_pool2d": self._max_pool2d,
14791480
"avg_pool2d": self._avg_pool2d,
14801481
"adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
14811482
"layer_norm": self._layer_norm,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,13 @@ def __init__(self):
796796
def forward(self, input):
797797
return self.pool(input)
798798

799+
class MaxPool2d_functional(Module):
800+
def __init__(self):
801+
super().__init__()
802+
803+
def forward(self, input):
804+
return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1])
805+
799806
@tvm.script.ir_module
800807
class expected1:
801808
@R.function
@@ -876,6 +883,7 @@ def main(
876883
return gv
877884

878885
verify_model(MaxPool2d(), input_info, {}, expected1)
886+
verify_model(MaxPool2d_functional(), input_info, {}, expected1)
879887
verify_model(MaxPool2d2(), input_info, {}, expected2)
880888
verify_model(MaxPool2d3(), input_info, {}, expected3)
881889

0 commit comments

Comments
 (0)