Skip to content

Commit

Permalink
dataset
Browse files: browse the repository at this point in the history
  • Loading branch information
Borda committed Jul 23, 2020
1 parent 2a5495f commit f336507
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions tests/base/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,7 @@ def __init__(self, root: str = PATH_DATASETS, train: bool = True,

data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
# FIXME: try to fix loading
for _ in range(30):
try:
self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file))
except Exception:
time.sleep(1)
else:
break
self.data, self.targets = _try_load(os.path.join(self.cached_folder_path, data_file))

def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
img = self.data[idx].float().unsqueeze(0)
Expand Down Expand Up @@ -111,6 +105,19 @@ def _download(self, data_folder: str) -> None:
urllib.request.urlretrieve(url, fpath)


def _try_load(path_data, trials=30, delay=1.0):
    """Load a ``torch.save``-d object from ``path_data``, retrying on failure.

    ``torch.load`` is attempted up to ``trials`` times, pausing ``delay``
    seconds between attempts. The retry presumably works around transient
    read failures (e.g. another process still writing the cache file) --
    the root cause is not visible from this function.

    Args:
        path_data: path to an existing file written by ``torch.save``.
        trials: maximum number of load attempts.
        delay: seconds to wait between two consecutive attempts.

    Returns:
        The object returned by ``torch.load``. If ``trials`` is not positive
        no attempt is made and ``None`` is returned.

    Raises:
        AssertionError: if ``path_data`` is not an existing file.
        Exception: whatever ``torch.load`` raised on the final attempt, so the
            real cause is reported instead of a ``None`` result that callers
            would fail to unpack.
    """
    assert os.path.isfile(path_data), 'Missing cached file: %s' % path_data
    res = None
    for attempt in range(trials):
        try:
            res = torch.load(path_data)
        except Exception:
            # Out of attempts: surface the genuine loading error right away
            # rather than sleeping once more and silently returning None.
            if attempt == trials - 1:
                raise
            time.sleep(delay)
        else:
            break
    return res


def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
tensor = tensor.clone()
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
Expand Down Expand Up @@ -195,7 +202,7 @@ def prepare_data(self, download: bool) -> None:
for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
path_fname = os.path.join(super().cached_folder_path, fname)
assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname
data, targets = torch.load(path_fname)
data, targets = _try_load(path_fname)
data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
torch.save((data, targets), os.path.join(self.cached_folder_path, fname))

Expand Down

0 comments on commit f336507

Please sign in to comment.