Skip to content

Commit

Permalink
Unpool/Unpool3d kernel support input of int64 indices 易用性提升 (#480)
Browse files Browse the repository at this point in the history
update torch.nn.functional.max_unpool
  • Loading branch information
NKNaN authored Sep 23, 2024
1 parent 09a5ec1 commit 46be5a8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 10 deletions.
6 changes: 3 additions & 3 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -12287,7 +12287,7 @@
"min_input_args": 0
},
"torch.nn.functional.max_unpool1d": {
"Matcher": "UnpoolMatcher",
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.max_unpool1d",
"args_list": [
"input",
Expand All @@ -12303,7 +12303,7 @@
"min_input_args": 3
},
"torch.nn.functional.max_unpool2d": {
"Matcher": "UnpoolMatcher",
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.max_unpool2d",
"args_list": [
"input",
Expand All @@ -12319,7 +12319,7 @@
"min_input_args": 3
},
"torch.nn.functional.max_unpool3d": {
"Matcher": "UnpoolMatcher",
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.max_unpool3d",
"args_list": [
"input",
Expand Down
7 changes: 0 additions & 7 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3878,13 +3878,6 @@ def generate_code(self, kwargs):
return code


class UnpoolMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs["indices"] = "{}.astype('int32')".format(kwargs["indices"])

return GenericMatcher.generate_code(self, kwargs)


class SoftmaxMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
Expand Down

0 comments on commit 46be5a8

Please sign in to comment.