Skip to content

Commit

Permalink
dataset: Add option to only catch NoneChucksSkipExceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
moshimeow committed Sep 22, 2022
1 parent 3736b3f commit 6ebe1cc
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion nonechucks/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
import torch.utils.data

from nonechucks.utils import memoize
from nonechucks.utils import NoneChucksSkipException


class SafeDataset(torch.utils.data.Dataset):
"""A wrapper around a torch.utils.data.Dataset that allows dropping
samples dynamically.
"""

def __init__(self, dataset, eager_eval=False):
def __init__(self, dataset, eager_eval=False, catch_all = True):
"""Creates a `SafeDataset` wrapper around `dataset`."""
self.dataset = dataset
self.eager_eval = eager_eval
self.catch_all = True
# These will contain indices over the original dataset. The indices of
# the safe samples will go into _safe_indices and similarly for unsafe
# samples.
Expand Down Expand Up @@ -41,6 +43,9 @@ def _safe_get_item(self, idx):
self._safe_indices.append(idx)
return sample
except Exception as e:
if not self.catch_all:
if not isinstance(e, NoneChucksSkipException):
raise
if isinstance(e, IndexError):
if invalid_idx:
raise
Expand Down

0 comments on commit 6ebe1cc

Please sign in to comment.