From a56268217e53ecda711152bf015766598b5368c3 Mon Sep 17 00:00:00 2001 From: FusionBolt <59008347+FusionBolt@users.noreply.github.com> Date: Sun, 4 Feb 2024 10:48:52 +0800 Subject: [PATCH] Supprot npy for pytest (#1164) * update * Apply code-format changes --------- Co-authored-by: FusionBolt --- tests/generator.py | 3 +++ tests/test_runner.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/generator.py b/tests/generator.py index 20d7bcd0a2..a59c611dc5 100644 --- a/tests/generator.py +++ b/tests/generator.py @@ -112,3 +112,6 @@ def from_image(self, shape: List[int], dtype: np.dtype, img_file: str) -> np.nda def from_constant_of_shape(self, shape: List[int], dtype: np.dtype) -> np.ndarray: return np.array(shape, dtype=dtype) + + def from_numpy(self, path) -> np.ndarray: + return np.load(path) diff --git a/tests/test_runner.py b/tests/test_runner.py index 58906311d2..0b58ba69a4 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -400,6 +400,12 @@ def generate_data(self, name: str, inputs: List[Dict], compile_opt, generator_cf file_list.extend([os.path.join(args, p) for p in os.listdir(args)]) elif method == 'constant_of_shape': assert len(args) != 0 + elif method == 'numpy': + assert(os.path.isdir(args)) + for file in os.listdir(args): + if file.endswith('.npy'): + file_list.append(os.path.join(args, file)) + file_list.sort() else: assert '{0} : not supported generator method'.format(method) @@ -420,19 +426,19 @@ def generate_data(self, name: str, inputs: List[Dict], compile_opt, generator_cf input_shape[0] *= generator_cfg['batch'] for batch_idx in range(batch_number): + idx = input_idx * batch_number + batch_idx if method == 'random': data = generator.from_random(input_shape, dtype, args) elif method == 'bin': - idx = input_idx * batch_number + batch_idx assert(idx < len(file_list)) data = generator.from_bin(input_shape, dtype, file_list[idx]) elif method == 'image': - idx = input_idx * batch_number + batch_idx assert(idx < len(file_list)) data = generator.from_image(input_shape, dtype, file_list[idx]) elif method == 'constant_of_shape': data = generator.from_constant_of_shape(args, dtype) - + elif method == 'numpy': + data = generator.from_numpy(file_list[idx]) if not test_utils.in_ci(): dump_bin_file(os.path.join(self.case_dir, name, f'{name}_{input_idx}_{batch_idx}.bin'), data)