diff --git a/assets/mel_filters.npz b/assets/mel_filters.npz index 913b906..b2f036a 100644 Binary files a/assets/mel_filters.npz and b/assets/mel_filters.npz differ diff --git a/extract_filterbank.py b/extract_filterbank.py index 165c0bb..4e1a96e 100644 --- a/extract_filterbank.py +++ b/extract_filterbank.py @@ -1,15 +1,8 @@ -import librosa +import mlx.core as mx import numpy as np -filterbank = librosa.filters.mel( - sr=24000, - n_fft=1024, - n_mels=100, - norm = None, - htk = True -) +from huggingface_hub import hf_hub_download -np.savez_compressed( - "assets/mel_filters.npz", - mel_100=filterbank -) +path = hf_hub_download("lucasnewman/vocos-mel-24khz", "model.safetensors") +filterbank = mx.load(path)["feature_extractor.mel_spec.mel_scale.fb"].moveaxis(0, 1) +np.savez_compressed("assets/mel_filters.npz", mel_100=filterbank) diff --git a/pyproject.toml b/pyproject.toml index 1e65812..4888dda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] requires = [ "huggingface_hub", - "mlx>=0.18.0", + "mlx", "numpy", "pyyaml", "setuptools", @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta" [project] name = "vocos-mlx" -version = "0.0.4" +version = "0.0.5" authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}] license = {text = "MIT"} description = "Vocos - MLX" diff --git a/vocos_mlx/model.py b/vocos_mlx/model.py index de44c1d..63e7d1c 100644 --- a/vocos_mlx/model.py +++ b/vocos_mlx/model.py @@ -20,7 +20,7 @@ def mel_filters(n_mels: int) -> mx.array: """ load the mel filterbank matrix for projecting STFT into a Mel spectrogram. - Allows decoupling librosa dependency; saved using extract_filterbank.py + Saved using extract_filterbank.py """ assert n_mels in {100}, f"Unsupported n_mels: {n_mels}" @@ -93,6 +93,7 @@ def log_mel_spectrogram( n_fft: int = 1024, hop_length: int = 256, padding: int = 0, + filterbank: Optional[mx.array] = None, ): """ Compute the log-Mel spectrogram of @@ -121,7 +122,7 @@ def log_mel_spectrogram( freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length) magnitudes = freqs[:-1, :].abs() - filters = mel_filters(n_mels) + filters = filterbank if filterbank is not None else mel_filters(n_mels) mel_spec = magnitudes @ filters.T log_spec = mx.maximum(mel_spec, 1e-5).log() return mx.expand_dims(log_spec, axis=0) @@ -137,20 +138,30 @@ def __call__(self, audio: mx.array, **kwargs) -> mx.array: class MelSpectrogramFeatures(FeatureExtractor): def __init__( self, - sample_rate=24000, + sample_rate=24_000, n_fft=1024, hop_length=256, n_mels=100, padding="center", + filterbank: Optional[mx.array] = None, ): super().__init__() if padding not in ["center", "same"]: raise ValueError("Padding must be 'center' or 'same'.") self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.n_mels = n_mels + self.filterbank = filterbank def __call__(self, audio, **kwargs): return log_mel_spectrogram( - audio, n_mels=100, n_fft=1024, hop_length=256, padding=0 + audio, + n_mels=self.n_mels, + n_fft=self.n_fft, + hop_length=self.hop_length, + padding=0, + filterbank=self.filterbank, ) @@ -433,7 +444,9 @@ def __init__( self.head = head @classmethod - def from_hparams(cls, config_path: str) -> Vocos: + def from_hparams( + cls, config_path: str, filterbank: Optional[mx.array] = None + ) -> Vocos: """ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. """ @@ -441,9 +454,10 @@ def from_hparams(cls, config_path: str) -> Vocos: config = SimpleNamespace(**yaml.safe_load(f)) if "MelSpectrogramFeatures" in config.feature_extractor["class_path"]: - feature_extractor = MelSpectrogramFeatures( - **config.feature_extractor["init_args"] - ) + feature_extractor_init_args = config.feature_extractor["init_args"] + if filterbank is not None: + feature_extractor_init_args["filterbank"] = filterbank + feature_extractor = MelSpectrogramFeatures(**feature_extractor_init_args) elif "EncodecFeatures" in config.feature_extractor["class_path"]: feature_extractor = EncodecFeatures(**config.feature_extractor["init_args"]) backbone = VocosBackbone(**config.backbone["init_args"]) @@ -466,22 +480,29 @@ def from_pretrained(cls, path_or_repo: str) -> Vocos: ) ) - config_path = path / "config.yaml" - model = cls.from_hparams(config_path) - model_path = path / "model.safetensors" with open(model_path, "rb") as f: weights = mx.load(f) + # load the filterbank for model initialization + + try: + filterbank = weights.pop( + "feature_extractor.mel_spec.mel_scale.fb" + ).moveaxis(0, 1) + except KeyError: + filterbank = None + + config_path = path / "config.yaml" + model = cls.from_hparams(config_path, filterbank) + # remove unused weights try: del weights["feature_extractor.mel_spec.spectrogram.window"] - del weights["feature_extractor.mel_spec.mel_scale.fb"] + del weights["head.istft.window"] except KeyError: pass - del weights["head.istft.window"] - # transpose weights as needed new_weights = {} for k, v in weights.items(): @@ -504,6 +525,12 @@ def __call__(self, audio_input: mx.array, **kwargs: Any) -> mx.array: audio_output = self.decode(features, **kwargs) return audio_output + def get_encodec_codes(self, audio_input: mx.array, bandwidth_id: int) -> mx.array: + if not isinstance(self.feature_extractor, EncodecFeatures): + raise ValueError("This model does not support getting encodec codes.") + + return self.feature_extractor.get_encodec_codes(audio_input, bandwidth_id) + def decode(self, features_input: mx.array, **kwargs: Any) -> mx.array: x = self.backbone(features_input, **kwargs) audio_output = self.head(x)