Skip to content

Commit

Permalink
feat: add --force to the download commands
Browse files (browse the repository at this point in the history)
  • Loading branch information
bdvllrs committed Jan 29, 2025
1 parent 1833031 commit d8cfce6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
29 changes: 27 additions & 2 deletions shimmer_ssd/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,20 @@ def download_group():
default="./checkpoints",
help="Where to download the checkpoints. Defaults to `./checkpoints`",
)
def download_dataset(path: Path):
@click.option(
"--force",
is_flag=True,
default=False,
type=bool,
help="If the file already exist, his will override with a new file.",
)
def download_dataset(path: Path, force: bool = False):
click.echo(f"Downloading in {str(path)}.")
if path.exists() and not force:
click.echo("Checkpoint path already exists. Skipping.")
return
elif path.exists():
click.echo("Checkpoint path already exists. Overriding.")
path.mkdir(exist_ok=True)
archive_path = path / "simple_shapes_checkpoints.tar.gz"
downlad_file(CHECKPOINTS_URL, archive_path)
Expand All @@ -45,8 +57,21 @@ def download_dataset(path: Path):
default="./tokenizer",
help="Where to download the tokenizer files",
)
def download_tokenizer(path: Path):
@click.option(
"--force",
is_flag=True,
default=False,
type=bool,
help="If the file already exist, his will override with a new file.",
)
def download_tokenizer(path: Path, force: bool = False):
click.echo(f"Downloading in {str(path)}.")
click.echo(f"Downloading in {str(path)}.")
if path.exists() and not force:
click.echo("Tokenizer path already exists. Skipping.")
return
elif path.exists():
click.echo("Tokenizer path already exists. Overriding.")
path.mkdir(exist_ok=True)
downlad_file(TOKENIZER_URL + "/merges.txt", path / "merges.txt")
downlad_file(TOKENIZER_URL + "/vocab.json", path / "vocab.json")
2 changes: 1 addition & 1 deletion shimmer_ssd/cli/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def save_v_latents(
latent_name = latent_name or (checkpoin_path.stem + ".npy")
train_path = dataset_path / f"saved_latents/train/{latent_name}"
if train_path.exists() and not force:
click.echo("Latent file already exists. Skipping")
click.echo("Latent file already exists. Skipping.")
return
elif train_path.exists():
click.echo("Latent file already exists. Overriding.")
Expand Down

0 comments on commit d8cfce6

Please sign in to comment.