diff --git a/nonechucks/__init__.py b/nonechucks/__init__.py index 054922b..e293f43 100644 --- a/nonechucks/__init__.py +++ b/nonechucks/__init__.py @@ -41,3 +41,4 @@ def _get_pytorch_version(): from nonechucks.dataset import SafeDataset from nonechucks.sampler import SafeSampler from nonechucks.dataloader import SafeDataLoader +from nonechucks.utils import NoneChucksSkipException diff --git a/nonechucks/dataset.py b/nonechucks/dataset.py index 342e6d3..c3bb462 100644 --- a/nonechucks/dataset.py +++ b/nonechucks/dataset.py @@ -2,6 +2,7 @@ import torch.utils.data from nonechucks.utils import memoize +from nonechucks.utils import NoneChucksSkipException class SafeDataset(torch.utils.data.Dataset): @@ -9,10 +10,11 @@ class SafeDataset(torch.utils.data.Dataset): samples dynamically. """ - def __init__(self, dataset, eager_eval=False): + def __init__(self, dataset, eager_eval=False, catch_all = True): """Creates a `SafeDataset` wrapper around `dataset`.""" self.dataset = dataset self.eager_eval = eager_eval + self.catch_all = catch_all # These will contain indices over the original dataset. The indices of # the safe samples will go into _safe_indices and similarly for unsafe # samples. @@ -41,6 +43,9 @@ def _safe_get_item(self, idx): self._safe_indices.append(idx) return sample except Exception as e: + if not self.catch_all: + if not isinstance(e, NoneChucksSkipException): + raise if isinstance(e, IndexError): if invalid_idx: raise diff --git a/nonechucks/utils.py b/nonechucks/utils.py index 9ad08a6..dae98a0 100644 --- a/nonechucks/utils.py +++ b/nonechucks/utils.py @@ -12,6 +12,10 @@ from torch._six import string_classes +class NoneChucksSkipException(Exception): + None # ...Chucks + + class memoize(object): """cache the return value of a method