From b92a3e4bdbf26e47515f5f90200f60996cf3e9ff Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Tue, 15 Oct 2024 02:53:02 +0000 Subject: [PATCH] Allow range reads on non-list dataset keys --- granular/__init__.py | 2 +- granular/dataset.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/granular/__init__.py b/granular/__init__.py index 327ae5a..64d6357 100644 --- a/granular/__init__.py +++ b/granular/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.20.0' +__version__ = '0.20.2' from .bag import BagWriter from .bag import BagReader diff --git a/granular/dataset.py b/granular/dataset.py index 0dc14b9..9a24e07 100644 --- a/granular/dataset.py +++ b/granular/dataset.py @@ -190,13 +190,14 @@ def __getitem__(self, index): assert 0 <= index.start <= index.stop and index.step in (None, 1), index refs = self._getrefs(index.start, index.stop) for i, (key, dtype) in enumerate(self.spec.items()): + ref, msk = [x[i] for x in refs], mask.get(key, False) # Cannot range read datapoints that contain sequence modalities, # because they may not be consecutive and thus could be slow. - assert not dtype.endswith('[]'), (index, key, dtype) - ref, msk = [x[i] for x in refs], mask.get(key, False) assert isinstance(msk, bool), (key, msk, type(msk)) - if msk: - needed[key] = range(ref[0], ref[-1] + 1) + assert not (msk and dtype.endswith('[]')), (index, key, dtype) + if not msk: + continue + needed[key] = range(ref[0], ref[-1] + 1) points = self._fetch(needed) decoded = { k: [self._decode(k, v, self.spec[k]) for v in vs]