Skip to content

Commit 666ded5

Browse files
authored
fix: max_size in DataDomain and ShimmerDataset (#180)
add max_size to DataDomain and changed max_size default value to None instead -1 in ShimmerDataset
1 parent 519d5f2 commit 666ded5

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

shimmer/data/dataset.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
dataset_path: str | Path,
5656
split: str,
5757
domain_classes: Mapping[DomainDesc, type[DataDomain]],
58-
max_size: int = -1,
58+
max_size: int | None = None,
5959
transforms: Mapping[str, Callable[[Any], Any]] | None = None,
6060
domain_args: Mapping[str, Any] | None = None,
6161
):
@@ -87,14 +87,15 @@ def __init__(
8787
self.domains[domain.kind] = domain_cls(
8888
dataset_path,
8989
split,
90+
max_size,
9091
transform,
9192
self.domain_args.get(domain.kind, None),
9293
)
9394

9495
lengths = {len(domain) for domain in self.domains.values()}
9596
assert len(lengths) == 1, "Domains have different lengths"
9697
self.dataset_size = next(iter(lengths))
97-
if self.max_size != -1:
98+
if self.max_size is not None:
9899
assert (
99100
self.max_size <= self.dataset_size
100101
), "Max sizes can only be lower than actual size."

shimmer/data/domain.py

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
self,
2424
dataset_path: str | Path,
2525
split: str,
26+
max_size: int | None = None,
2627
transform: Callable[[Any], _T] | None = None,
2728
additional_args: dict[str, Any] | None = None,
2829
) -> None:

0 commit comments

Comments
 (0)