From 46be5a8f5a6cbedd57cd812a164e5cd4899d849b Mon Sep 17 00:00:00 2001 From: Nana <49900969+NKNaN@users.noreply.github.com> Date: Mon, 23 Sep 2024 15:03:27 +0800 Subject: [PATCH] =?UTF-8?q?Unpool/Unpool3d=20kernel=20support=20input=20of?= =?UTF-8?q?=20int64=20`indices`=20=E6=98=93=E7=94=A8=E6=80=A7=E6=8F=90?= =?UTF-8?q?=E5=8D=87=20(#480)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit update torch.nn.functional.max_unpool --- paconvert/api_mapping.json | 6 +++--- paconvert/api_matcher.py | 7 ------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 2335aab1b..2791a311b 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -12287,7 +12287,7 @@ "min_input_args": 0 }, "torch.nn.functional.max_unpool1d": { - "Matcher": "UnpoolMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.functional.max_unpool1d", "args_list": [ "input", @@ -12303,7 +12303,7 @@ "min_input_args": 3 }, "torch.nn.functional.max_unpool2d": { - "Matcher": "UnpoolMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.functional.max_unpool2d", "args_list": [ "input", @@ -12319,7 +12319,7 @@ "min_input_args": 3 }, "torch.nn.functional.max_unpool3d": { - "Matcher": "UnpoolMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.functional.max_unpool3d", "args_list": [ "input", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index b1a901209..fa0dfb2be 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3878,13 +3878,6 @@ def generate_code(self, kwargs): return code -class UnpoolMatcher(BaseMatcher): - def generate_code(self, kwargs): - kwargs["indices"] = "{}.astype('int32')".format(kwargs["indices"]) - - return GenericMatcher.generate_code(self, kwargs) - - class SoftmaxMatcher(BaseMatcher): def generate_aux_code(self): CODE_TEMPLATE = textwrap.dedent(