@sidharthrajaram commented Dec 7, 2023

Adds the ability to download and cache checkpoints when loading them (similar to the Transformers API).

  • The checkpoint_path argument of Phonemizer.from_checkpoint() now accepts a checkpoint filename or a checkpoint URL, in addition to a local file path as before. In every case the value must end with .pt. See the test results below.
  • If only a checkpoint name is provided, the loader tries to retrieve the checkpoint from DEFAULT_MODEL_BUCKET (see dp/model/model.py).
  • Caching is handled by cached_path.
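
The resolution order described in the list above can be sketched roughly as follows. This is an illustrative sketch, not the code in this PR: the helper name resolve_checkpoint_location is hypothetical, the value of DEFAULT_MODEL_BUCKET is an assumption inferred from the URL in the 403 error in the test cases, and the order in which URLs and local paths are checked is a guess.

```python
import os
from urllib.parse import urlparse

# Assumed value, inferred from the URL that appears in the 403 error below.
# The real constant lives in dp/model/model.py and may differ.
DEFAULT_MODEL_BUCKET = (
    "https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer"
)


def resolve_checkpoint_location(checkpoint_path: str) -> str:
    """Hypothetical helper: map a name, URL, or file path to a fetchable location."""
    # Every accepted form must end with .pt.
    if not checkpoint_path.endswith(".pt"):
        raise ValueError(f"{checkpoint_path} is not a valid model file (.pt).")

    # Full URLs are passed through unchanged.
    if urlparse(checkpoint_path).scheme in ("http", "https", "s3"):
        return checkpoint_path

    # Existing local files are used as before.
    if os.path.exists(checkpoint_path):
        return checkpoint_path

    # Otherwise treat the value as a bare checkpoint name in the default bucket.
    return f"{DEFAULT_MODEL_BUCKET}/{checkpoint_path}"
```

Under these assumptions, the returned string would then be handed to cached_path, which downloads a remote file on first use and returns the path of the locally cached copy on later calls.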

Why?

  • Convenient checkpoint loading, with behavior familiar from Transformers/HuggingFace.
  • No need to manage checkpoint locations manually.
  • Easier model fetching when your own trained checkpoints are hosted on S3, MinIO, etc.

Test cases:

Invalid file (no .pt extension)

>>> phonemizer = Phonemizer.from_checkpoint('test')
ValueError: test is not a valid model file (.pt).

Invalid checkpoint name (not available in the default bucket)

>>> phonemizer = Phonemizer.from_checkpoint('test.pt')
...
requests.exceptions.HTTPError: 403 Client Error: Forbidden for url: https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/test.pt

Checkpoint name provided

>>> phonemizer = Phonemizer.from_checkpoint('en_us_cmudict_forward.pt')
Loading model from /PATH/TO/.cache/cached_path/6c84425c...
>>>

Cached model loaded

>>> phonemizer = Phonemizer.from_checkpoint('en_us_cmudict_forward.pt')
Loading model from /PATH/TO/.cache/cached_path/6c84425c...
>>>

URL to checkpoint provided

>>> phonemizer = Phonemizer.from_checkpoint('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')
Loading model from /PATH/TO/.cache/cached_path/3f662135...
>>>

Cached model loaded (with just checkpoint name)

>>> phonemizer = Phonemizer.from_checkpoint('en_us_cmudict_ipa_forward.pt')
en_us_cmudict_ipa_forward.pt already exists in cache.
Loading model from /PATH/TO/.cache/cached_path/3f662135...
>>>

Cached model loaded (with checkpoint URL)

>>> phonemizer = Phonemizer.from_checkpoint('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')
en_us_cmudict_ipa_forward.pt already exists in cache.
Loading model from /PATH/TO/.cache/cached_path/3f662135...
>>>
