Skip to content

Commit

Permalink
fix: remove the use of max_train_size so that scripts work on all sized datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Jan 30, 2025
1 parent 09bf329 commit 6038a5e
Show file tree
Hide file tree
Showing 10 changed files with 2 additions and 18 deletions.
7 changes: 1 addition & 6 deletions docs/config_parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ dataset:
# Path to the simple-shapes-dataset. Can be downloaded with `shapesd download`
path: "./simple_shapes_dataset" # (type: Path)
# Max number of unpaired examples used during training.
# This is here for legacy reasons. Prefer changing `domain_proportions`.
# The proportion is relative to this value.
max_train_size: 500_000 # (type: int | None)
training:
batch_size: 2056 # (type: int)
num_workers: 16 # (type: int)
Expand Down Expand Up @@ -125,7 +120,7 @@ title: null # (type: str | None)
# alias `d`
desc: null # (type: str | None)
# Proportion of each domain in the dataset relative to `dataset.max_train_size`
# Proportion of each domain in the dataset relative to the size of the dataset
domain_proportions: [] # (type: Sequence[DomainProportion])
# For example:
# domain_proportions:
Expand Down
1 change: 0 additions & 1 deletion scripts/exploration/test_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def main():
get_default_domains(["t", "v"]),
{frozenset(["t"]): 1.0, frozenset(["v"]): 1.0},
batch_size=config.training.batch_size,
max_train_size=config.dataset.max_train_size,
num_workers=config.training.num_workers,
domain_args={
"t": {"latent_filename": config.domain_modules.text.latent_filename}
Expand Down
1 change: 0 additions & 1 deletion scripts/save_v_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def main():
get_default_domains(["v"]),
{frozenset(["v"]): 1.0},
batch_size=config.training.batch_size,
max_train_size=config.dataset.max_train_size,
num_workers=config.training.num_workers,
seed=config.seed,
additional_transforms=additional_transforms,
Expand Down
1 change: 0 additions & 1 deletion scripts/train_text2attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def main():
domain_classes,
domain_proportion,
batch_size=config.training.batch_size,
max_train_size=config.dataset.max_train_size,
num_workers=config.training.num_workers,
seed=config.seed,
ood_seed=config.ood_seed,
Expand Down
1 change: 0 additions & 1 deletion shimmer_ssd/cli/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def save_v_latents(
get_default_domains(["v"]),
{frozenset(["v"]): 1.0},
batch_size=config.training.batch_size,
max_train_size=config.dataset.max_train_size,
num_workers=config.training.num_workers,
seed=config.seed,
additional_transforms=additional_transforms,
Expand Down
1 change: 0 additions & 1 deletion shimmer_ssd/cli/train_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def train_attr_domain(
get_default_domains(["attr"]),
{frozenset(["attr"]): 1.0},
batch_size=config.training.batch_size,
max_train_size=config.dataset.max_train_size,
num_workers=config.training.num_workers,
)

Expand Down
1 change: 0 additions & 1 deletion shimmer_ssd/cli/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def train_gw(
domain_classes,
config.domain_proportions,
batch_size=config.training.batch_size,
max_train_size=config.dataset.max_train_size,
num_workers=config.training.num_workers,
seed=config.seed,
ood_seed=config.ood_seed,
Expand Down
1 change: 0 additions & 1 deletion shimmer_ssd/cli/train_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def train_t_domain(
get_default_domains(["t"]),
{frozenset(["t"]): 1.0},
batch_size=config.training.batch_size,
max_train_size=config.dataset.max_train_size,
num_workers=config.training.num_workers,
domain_args={
"t": {"latent_filename": config.domain_modules.text.latent_filename}
Expand Down
1 change: 0 additions & 1 deletion shimmer_ssd/cli/train_v.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def train_visual_domain(
get_default_domains(["v"]),
{frozenset(["v"]): 1.0},
batch_size=config.training.batch_size,
max_train_size=config.dataset.max_train_size,
num_workers=config.training.num_workers,
additional_transforms=additional_transforms,
)
Expand Down
5 changes: 1 addition & 4 deletions shimmer_ssd/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,6 @@ class Dataset(BaseModel):

# Path to the dataset obtainable on https://github.com/ruflab/simple-shapes-dataset
path: Path
# Max number of unpaired examples used during training.
# Prefer changing `domain_proportions`. The proportion is relative to this value.
max_train_size: int | None = 500_000


class VisualModule(BaseModel):
Expand Down Expand Up @@ -426,7 +423,7 @@ class Config(ParsedModel):
title: str | None = Field(None, alias="t")
# Add a description to your run
desc: str | None = Field(None, alias="d")
# proportion of each domain in the dataset relative to `dataset.max_train_size`
# proportion of each domain in the dataset relative to the size of the dataset
domain_proportions: Mapping[frozenset[str], float] = {}
# Config of the different domain modules
domain_modules: DomainModules = DomainModules()
Expand Down

0 comments on commit 6038a5e

Please sign in to comment.