From 1793fe5e95f8762a6699529391f3356bd650cba7 Mon Sep 17 00:00:00 2001 From: Hubert Siuzdak <35269911+hubertsiuzdak@users.noreply.github.com> Date: Fri, 13 Oct 2023 21:38:43 +0200 Subject: [PATCH] v0.0.4 (#29) * Replace 'Self' type (requires python>=3.11); Add optional revision to hf_hub_download * Bump version --- vocos/__init__.py | 2 +- vocos/pretrained.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vocos/__init__.py b/vocos/__init__.py index f363437..093920f 100644 --- a/vocos/__init__.py +++ b/vocos/__init__.py @@ -1,4 +1,4 @@ from vocos.pretrained import Vocos -__version__ = "0.0.3" +__version__ = "0.0.4" diff --git a/vocos/pretrained.py b/vocos/pretrained.py index 6f5cde1..a8a5935 100644 --- a/vocos/pretrained.py +++ b/vocos/pretrained.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, Self, Tuple, Union +from typing import Any, Dict, Tuple, Union, Optional import torch import yaml @@ -47,7 +47,7 @@ def __init__( self.head = head @classmethod - def from_hparams(cls, config_path: str) -> Self: + def from_hparams(cls, config_path: str) -> Vocos: """ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. """ @@ -60,12 +60,12 @@ def from_hparams(cls, config_path: str) -> Self: return model @classmethod - def from_pretrained(cls, repo_id: str) -> Self: + def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos: """ Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. """ - config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml") - model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin") + config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml", revision=revision) + model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", revision=revision) model = cls.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu") if isinstance(model.feature_extractor, EncodecFeatures):