diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 63e3d424e4dd..2320e5b21a7c 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -66,11 +66,13 @@ def _load_vocoder(model_name: Optional[str], checkpoint_path: Optional[str], typ raise ValueError(f"Unknown vocoder type '{type}'") if model_name is not None: - vocoder = model_type.from_pretrained(model_name).eval() + vocoder = model_type.from_pretrained(model_name) + elif checkpoint_path.endswith(".nemo"): + vocoder = model_type.restore_from(checkpoint_path) else: - vocoder = model_type.load_from_checkpoint(checkpoint_path).eval() + vocoder = model_type.load_from_checkpoint(checkpoint_path) - return vocoder + return vocoder.eval() @dataclass