From 156040ae03ae231ecc03cfa9cb20b55bbee99b1e Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 23 Jun 2025 15:17:46 +0800 Subject: [PATCH] add type promotion logic --- tester/base.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tester/base.py b/tester/base.py index b6933f4b..3dfe5a1b 100644 --- a/tester/base.py +++ b/tester/base.py @@ -487,6 +487,32 @@ def gen_paddle_input(self): elif "driver" in self.paddle_kwargs: self.paddle_kwargs["driver"] = "gels" + if self.api_config.api_name == "paddle.concat": + # handle type promotion here + input_list = self.paddle_kwargs.get("x", None) + is_x_kwargs = input_list is not None + input_list = self.paddle_args[0] if not is_x_kwargs else input_list + is_tuple = isinstance(input_list, tuple) + if is_tuple: + input_list = list(input_list) + promoted_type = numpy.promote_types(str(input_list[0].dtype).split('.')[-1], str(input_list[1].dtype).split('.')[-1]) + num_inputs = len(input_list) + if num_inputs > 2: + for i in range(2, num_inputs): + promoted_type = numpy.promote_types(str(input_list[i].dtype).split('.')[-1], promoted_type) + for i in range(num_inputs): + input_list[i] = input_list[i].astype(str(promoted_type)) + if is_tuple: + if is_x_kwargs: + self.paddle_kwargs["x"] = tuple(input_list) + else: + self.paddle_args[0] = tuple(input_list) + else: + if is_x_kwargs: + self.paddle_kwargs["x"] = input_list + else: + self.paddle_args[0] = input_list + if self.need_check_grad(): if (self.api_config.api_name[-1] == "_" and self.api_config.api_name[-2:] != "__") or self.api_config.api_name == "paddle.Tensor.__setitem__": self.paddle_args, self.paddle_kwargs = self.copy_paddle_input() @@ -842,6 +868,25 @@ def gen_torch_input(self): else: self.torch_kwargs[key] = arg_config + if self.api_config.api_name == "paddle.concat": + # handle type promotion here + input_list = self.torch_kwargs["x"] + is_tuple = isinstance(input_list, tuple) + if is_tuple: + input_list = list(input_list) + input_list = list(input_list) + promoted_type = numpy.promote_types(str(input_list[0].dtype).split('.')[-1], str(input_list[1].dtype).split('.')[-1]) + num_inputs = len(input_list) + if num_inputs > 2: + for i in range(2, num_inputs): + promoted_type = numpy.promote_types(str(input_list[i].dtype).split('.')[-1], promoted_type) + for i in range(num_inputs): + input_list[i] = input_list[i].type(self.convert_dtype_to_torch_type(promoted_type)) + if is_tuple: + self.torch_kwargs["x"] = tuple(input_list) + else: + self.torch_kwargs["x"] = input_list + if self.need_check_grad(): if (self.api_config.api_name[-1] == "_" and self.api_config.api_name[-2:] != "__") or self.api_config.api_name == "paddle.Tensor.__setitem__": self.torch_args, self.torch_kwargs = self.copy_torch_input()