diff --git a/test/ir/inference/program_config.py b/test/ir/inference/program_config.py
index a2f36617de8e6..f9424502484cc 100644
--- a/test/ir/inference/program_config.py
+++ b/test/ir/inference/program_config.py
@@ -275,6 +275,7 @@ def generate_weight():
         self.outputs = outputs
         self.input_type = input_type
         self.no_cast_list = [] if no_cast_list is None else no_cast_list
+        self.supported_cast_type = [np.float32, np.float16]
 
     def __repr__(self):
         log_str = ''
@@ -292,11 +293,9 @@ def __repr__(self):
         return log_str
 
     def set_input_type(self, _type: np.dtype) -> None:
-        assert _type in [
-            np.float32,
-            np.float16,
-            None,
-        ], "PaddleTRT only supports FP32 / FP16 IO"
+        assert (
+            _type in self.supported_cast_type or _type is None
+        ), "PaddleTRT only supports FP32 / FP16 IO"
 
         ver = paddle.inference.get_trt_compile_version()
         trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
@@ -309,15 +308,14 @@ def set_input_type(self, _type: np.dtype) -> None:
     def get_feed_data(self) -> Dict[str, Dict[str, Any]]:
         feed_data = {}
         for name, tensor_config in self.inputs.items():
-            do_casting = (
-                self.input_type is not None and name not in self.no_cast_list
-            )
+            data = tensor_config.data
             # Cast to target input_type
-            data = (
-                tensor_config.data.astype(self.input_type)
-                if do_casting
-                else tensor_config.data
-            )
+            if (
+                self.input_type is not None
+                and name not in self.no_cast_list
+                and data.dtype in self.supported_cast_type
+            ):
+                data = data.astype(self.input_type)
             # Truncate FP32 tensors to FP16 precision for FP16 test stability
             if data.dtype == np.float32 and name not in self.no_cast_list:
                 data = data.astype(np.float16).astype(np.float32)
@@ -334,10 +332,14 @@ def _cast(self) -> None:
         for name, inp in self.inputs.items():
             if name in self.no_cast_list:
                 continue
+            if inp.dtype not in self.supported_cast_type:
+                continue
             inp.convert_type_inplace(self.input_type)
         for name, weight in self.weights.items():
             if name in self.no_cast_list:
                 continue
+            if weight.dtype not in self.supported_cast_type:
+                continue
             weight.convert_type_inplace(self.input_type)
         return self
 
diff --git a/test/ir/inference/test_trt_convert_assign.py b/test/ir/inference/test_trt_convert_assign.py
index 55939982d5ee0..99b027877bc9c 100644
--- a/test/ir/inference/test_trt_convert_assign.py
+++ b/test/ir/inference/test_trt_convert_assign.py
@@ -120,9 +120,8 @@ def clear_dynamic_shape():
             self.dynamic_shape.opt_input_shape = {}
 
         def generate_trt_nodes_num(attrs, dynamic_shape):
-            if not dynamic_shape and (
-                self.has_bool_dtype or self.dims == 1 or self.dims == 0
-            ):
+            # Static shape does not support 0 or 1 dim's input
+            if not dynamic_shape and (self.dims == 1 or self.dims == 0):
                 return 0, 4
             return 1, 2
 
diff --git a/test/ir/inference/test_trt_convert_cast.py b/test/ir/inference/test_trt_convert_cast.py
index 026abc571050a..0b5f2186429e3 100644
--- a/test/ir/inference/test_trt_convert_cast.py
+++ b/test/ir/inference/test_trt_convert_cast.py
@@ -118,6 +118,7 @@ def generate_input(type):
                        )
                    },
                    outputs=["cast_output_data1"],
+                   no_cast_list=["input_data"],
                )
                yield program_config
 