diff --git a/tests/base/datasets.py b/tests/base/datasets.py index b1991de6aa9e6..e55ef3484f7ea 100644 --- a/tests/base/datasets.py +++ b/tests/base/datasets.py @@ -63,13 +63,7 @@ def __init__(self, root: str = PATH_DATASETS, train: bool = True, data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME # FIXME: try to fix loading - for _ in range(30): - try: - self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file)) - except Exception: - time.sleep(1) - else: - break + self.data, self.targets = _try_load(os.path.join(self.cached_folder_path, data_file)) def __getitem__(self, idx: int) -> Tuple[Tensor, int]: img = self.data[idx].float().unsqueeze(0) @@ -111,6 +105,19 @@ def _download(self, data_folder: str) -> None: urllib.request.urlretrieve(url, fpath) +def _try_load(path_data, trials=30): + res = None + assert os.path.isfile(path_data) + for _ in range(trials): + try: + res = torch.load(path_data) + except Exception: + time.sleep(1) + else: + break + return res + + def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor: tensor = tensor.clone() mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) @@ -195,7 +202,7 @@ def prepare_data(self, download: bool) -> None: for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): path_fname = os.path.join(super().cached_folder_path, fname) assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname - data, targets = torch.load(path_fname) + data, targets = _try_load(path_fname) data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits) torch.save((data, targets), os.path.join(self.cached_folder_path, fname))