Skip to content

Commit 162f8b8

Browse files
authored
revert: fix: max_size in DataDomain and ShimmerDataset (#180) (#182)
Revert the addition of `max_size` to `DataDomain`. Instead, emit a warning if the domains have different lengths. This avoids requiring every `DataDomain` to implement `max_size`.
1 parent d94da27 commit 162f8b8

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

shimmer/data/dataset.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Callable, Mapping
23
from pathlib import Path
34
from typing import Any
@@ -65,7 +66,7 @@ def __init__(
6566
split (str): Split to use. One of 'train', 'val', 'test'.
6667
domain_classes (Mapping[str, type[SimpleShapesDomain]]): Classes of
6768
domain loaders to include in the dataset.
68-
max_size (int): Max size of the dataset.
69+
max_size (int | None): Max size of the dataset.
6970
transforms (Mapping[str, (Any) -> Any]): Optional transforms to apply
7071
to the domains. The keys are the domain names,
7172
the values are the transforms.
@@ -87,14 +88,19 @@ def __init__(
8788
self.domains[domain.kind] = domain_cls(
8889
dataset_path,
8990
split,
90-
max_size,
9191
transform,
9292
self.domain_args.get(domain.kind, None),
9393
)
9494

9595
lengths = {len(domain) for domain in self.domains.values()}
96-
assert len(lengths) == 1, "Domains have different lengths"
97-
self.dataset_size = next(iter(lengths))
96+
min_length = min(lengths)
97+
if len(lengths) != 1:
98+
warnings.warn(
99+
f"Domains have different lengths. Selecting min ({min_length}).",
100+
UserWarning,
101+
stacklevel=2,
102+
)
103+
self.dataset_size = min_length
98104
if self.max_size is not None:
99105
assert (
100106
self.max_size <= self.dataset_size

shimmer/data/domain.py

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(
2323
self,
2424
dataset_path: str | Path,
2525
split: str,
26-
max_size: int | None = None,
2726
transform: Callable[[Any], _T] | None = None,
2827
additional_args: dict[str, Any] | None = None,
2928
) -> None:

0 commit comments

Comments
 (0)