Skip to content

Commit

Permalink
Supprot npy for pytest (#1164)
Browse files Browse the repository at this point in the history
* update

* Apply code-format changes

---------

Co-authored-by: FusionBolt <FusionBolt@users.noreply.github.com>
FusionBolt and FusionBolt authored Feb 4, 2024
1 parent e7d0aa1 commit a562682
Showing 2 changed files with 12 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tests/generator.py
Original file line number Diff line number Diff line change
@@ -112,3 +112,6 @@ def from_image(self, shape: List[int], dtype: np.dtype, img_file: str) -> np.nda

def from_constant_of_shape(self, shape: List[int], dtype: np.dtype) -> np.ndarray:
return np.array(shape, dtype=dtype)

def from_numpy(self, path) -> np.ndarray:
return np.load(path)
12 changes: 9 additions & 3 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -400,6 +400,12 @@ def generate_data(self, name: str, inputs: List[Dict], compile_opt, generator_cf
file_list.extend([os.path.join(args, p) for p in os.listdir(args)])
elif method == 'constant_of_shape':
assert len(args) != 0
elif method == 'numpy':
assert(os.path.isdir(args))
for file in os.listdir(args):
if file.endswith('.npy'):
file_list.append(os.path.join(args, file))
file_list.sort()
else:
assert '{0} : not supported generator method'.format(method)

@@ -420,19 +426,19 @@ def generate_data(self, name: str, inputs: List[Dict], compile_opt, generator_cf
input_shape[0] *= generator_cfg['batch']

for batch_idx in range(batch_number):
idx = input_idx * batch_number + batch_idx
if method == 'random':
data = generator.from_random(input_shape, dtype, args)
elif method == 'bin':
idx = input_idx * batch_number + batch_idx
assert(idx < len(file_list))
data = generator.from_bin(input_shape, dtype, file_list[idx])
elif method == 'image':
idx = input_idx * batch_number + batch_idx
assert(idx < len(file_list))
data = generator.from_image(input_shape, dtype, file_list[idx])
elif method == 'constant_of_shape':
data = generator.from_constant_of_shape(args, dtype)

elif method == 'numpy':
data = generator.from_numpy(file_list[idx])
if not test_utils.in_ci():
dump_bin_file(os.path.join(self.case_dir, name,
f'{name}_{input_idx}_{batch_idx}.bin'), data)

0 comments on commit a562682

Please sign in to comment.