diff --git a/shimmer/data/dataset.py b/shimmer/data/dataset.py index 34136534..8776cccc 100644 --- a/shimmer/data/dataset.py +++ b/shimmer/data/dataset.py @@ -55,7 +55,7 @@ def __init__( dataset_path: str | Path, split: str, domain_classes: Mapping[DomainDesc, type[DataDomain]], - max_size: int = -1, + max_size: int | None = None, transforms: Mapping[str, Callable[[Any], Any]] | None = None, domain_args: Mapping[str, Any] | None = None, ): @@ -94,7 +94,7 @@ def __init__( lengths = {len(domain) for domain in self.domains.values()} assert len(lengths) == 1, "Domains have different lengths" self.dataset_size = next(iter(lengths)) - if self.max_size != -1: + if self.max_size is not None: assert ( self.max_size <= self.dataset_size ), "Max sizes can only be lower than actual size." diff --git a/shimmer/data/domain.py b/shimmer/data/domain.py index 88fc758c..a76d0e0d 100644 --- a/shimmer/data/domain.py +++ b/shimmer/data/domain.py @@ -23,6 +23,7 @@ def __init__( self, dataset_path: str | Path, split: str, + max_size: int | None = None, transform: Callable[[Any], _T] | None = None, additional_args: dict[str, Any] | None = None, ) -> None: