Skip to content

Commit

Permalink
Revert "fix: set 500,000 as the default value for max_sizes..."
Browse files Browse the repository at this point in the history
This reverts commit d6ba09c.
  • Loading branch information
bdvllrs committed Jan 15, 2025
1 parent d6ba09c commit 79c6d71
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions simple_shapes_dataset/cli/alignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def create_domain_split(
seed: int,
dataset_path: Path,
domain_alignment: list[tuple[str, float]],
max_train_size: int | None = 500_000,
max_train_size: int | None = None,
):
if not len(domain_alignment):
return
Expand Down Expand Up @@ -66,7 +66,7 @@ def create_domain_split(
@click.option(
"--max_train_size",
"--ms",
default=500_000,
default=None,
type=int,
help="Max index to use for the train set.",
)
Expand All @@ -83,7 +83,7 @@ def create_domain_split(
def add_alignment_split(
seed: int,
dataset_path: str,
max_train_size: int,
max_train_size: int | None,
domain_alignment: list[tuple[str, float]],
) -> None:
dataset_location = Path(dataset_path)
Expand Down
4 changes: 2 additions & 2 deletions simple_shapes_dataset/cli/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def create_unpaired_attributes(
@click.option(
"--max_train_size",
"--ms",
default=500_000,
default=None,
type=int,
help="Max index to use for the train set.",
)
Expand All @@ -133,7 +133,7 @@ def create_dataset(
min_lightness: int,
max_lightness: int,
bert_path: str,
max_train_size: int,
max_train_size: int | None,
domain_alignment: list[tuple[str, float]],
) -> None:
dataset_location = Path(output_path)
Expand Down
2 changes: 1 addition & 1 deletion simple_shapes_dataset/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
domain_classes: Mapping[DomainDesc, type[DataDomain]],
domain_proportions: Mapping[frozenset[str], float],
batch_size: int,
max_train_size: int | None = 500_000,
max_train_size: int | None = None,
num_workers: int = 0,
seed: int | None = None,
ood_seed: int | None = None,
Expand Down
4 changes: 2 additions & 2 deletions simple_shapes_dataset/domain_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_alignment(
split: str,
domain_proportions: Mapping[frozenset[str], float],
seed: int,
max_size: int | None = 500_000,
max_size: int | None,
) -> Mapping[frozenset[str], np.ndarray]:
assert split in ["train", "val", "test"]

Expand Down Expand Up @@ -55,7 +55,7 @@ def get_aligned_datasets(
domain_classes: Mapping[DomainDesc, type[DataDomain]],
domain_proportions: Mapping[frozenset[str], float],
seed: int,
max_size: int | None = 500_000,
max_size: int | None = None,
transforms: Mapping[str, Callable[[Any], Any]] | None = None,
domain_args: Mapping[str, Any] | None = None,
) -> dict[frozenset[str], Subset]:
Expand Down

0 comments on commit 79c6d71

Please sign in to comment.