From 79c6d716e0ec0eb5ba45b6d464ba913d498b1b82 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Wed, 15 Jan 2025 09:51:17 +0000 Subject: [PATCH] Revert "fix: set 500,000 as the default value for max_sizes..." This reverts commit d6ba09c0766c58cc67d4cc4062eb543692bceff7. --- simple_shapes_dataset/cli/alignments.py | 6 +++--- simple_shapes_dataset/cli/create_dataset.py | 4 ++-- simple_shapes_dataset/data_module.py | 2 +- simple_shapes_dataset/domain_alignment.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/simple_shapes_dataset/cli/alignments.py b/simple_shapes_dataset/cli/alignments.py index 73fd9fc..28155ba 100644 --- a/simple_shapes_dataset/cli/alignments.py +++ b/simple_shapes_dataset/cli/alignments.py @@ -18,7 +18,7 @@ def create_domain_split( seed: int, dataset_path: Path, domain_alignment: list[tuple[str, float]], - max_train_size: int | None = 500_000, + max_train_size: int | None = None, ): if not len(domain_alignment): return @@ -66,7 +66,7 @@ def create_domain_split( @click.option( "--max_train_size", "--ms", - default=500_000, + default=None, type=int, help="Max index to use for the train set.", ) @@ -83,7 +83,7 @@ def create_domain_split( def add_alignment_split( seed: int, dataset_path: str, - max_train_size: int, + max_train_size: int | None, domain_alignment: list[tuple[str, float]], ) -> None: dataset_location = Path(dataset_path) diff --git a/simple_shapes_dataset/cli/create_dataset.py b/simple_shapes_dataset/cli/create_dataset.py index c69abaf..366089c 100644 --- a/simple_shapes_dataset/cli/create_dataset.py +++ b/simple_shapes_dataset/cli/create_dataset.py @@ -107,7 +107,7 @@ def create_unpaired_attributes( @click.option( "--max_train_size", "--ms", - default=500_000, + default=None, type=int, help="Max index to use for the train set.", ) @@ -133,7 +133,7 @@ def create_dataset( min_lightness: int, max_lightness: int, bert_path: str, - max_train_size: int, + max_train_size: int | None, domain_alignment: list[tuple[str, float]], ) -> None: dataset_location = Path(output_path) diff --git a/simple_shapes_dataset/data_module.py b/simple_shapes_dataset/data_module.py index 9c01e35..61eaeae 100644 --- a/simple_shapes_dataset/data_module.py +++ b/simple_shapes_dataset/data_module.py @@ -33,7 +33,7 @@ def __init__( domain_classes: Mapping[DomainDesc, type[DataDomain]], domain_proportions: Mapping[frozenset[str], float], batch_size: int, - max_train_size: int | None = 500_000, + max_train_size: int | None = None, num_workers: int = 0, seed: int | None = None, ood_seed: int | None = None, diff --git a/simple_shapes_dataset/domain_alignment.py b/simple_shapes_dataset/domain_alignment.py index 0adf796..7a47386 100644 --- a/simple_shapes_dataset/domain_alignment.py +++ b/simple_shapes_dataset/domain_alignment.py @@ -15,7 +15,7 @@ def get_alignment( split: str, domain_proportions: Mapping[frozenset[str], float], seed: int, - max_size: int | None = 500_000, + max_size: int | None, ) -> Mapping[frozenset[str], np.ndarray]: assert split in ["train", "val", "test"] @@ -55,7 +55,7 @@ def get_aligned_datasets( domain_classes: Mapping[DomainDesc, type[DataDomain]], domain_proportions: Mapping[frozenset[str], float], seed: int, - max_size: int | None = 500_000, + max_size: int | None = None, transforms: Mapping[str, Callable[[Any], Any]] | None = None, domain_args: Mapping[str, Any] | None = None, ) -> dict[frozenset[str], Subset]: