Skip to content

Commit

Permalink
[TTS] Support .nemo checkpoint in FP callback
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed May 25, 2023
1 parent 37f75d6 commit f9988d0
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions nemo/collections/tts/parts/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ def _load_vocoder(model_name: Optional[str], checkpoint_path: Optional[str], typ
raise ValueError(f"Unknown vocoder type '{type}'")

if model_name is not None:
vocoder = model_type.from_pretrained(model_name).eval()
vocoder = model_type.from_pretrained(model_name)
elif checkpoint_path.endswith(".nemo"):
vocoder = model_type.restore_from(checkpoint_path)
else:
vocoder = model_type.load_from_checkpoint(checkpoint_path).eval()
vocoder = model_type.load_from_checkpoint(checkpoint_path)

return vocoder
return vocoder.eval()


@dataclass
Expand Down

0 comments on commit f9988d0

Please sign in to comment.