diff --git a/assets/mel_filters.npz b/assets/mel_filters.npz deleted file mode 100644 index b2f036a..0000000 Binary files a/assets/mel_filters.npz and /dev/null differ diff --git a/extract_filterbank.py b/extract_filterbank.py deleted file mode 100644 index 4e1a96e..0000000 --- a/extract_filterbank.py +++ /dev/null @@ -1,8 +0,0 @@ -import mlx.core as mx -import numpy as np - -from huggingface_hub import hf_hub_download - -path = hf_hub_download("lucasnewman/vocos-mel-24khz", "model.safetensors") -filterbank = mx.load(path)["feature_extractor.mel_spec.mel_scale.fb"].moveaxis(0, 1) -np.savez_compressed("assets/mel_filters.npz", mel_100=filterbank) diff --git a/vocos_mlx/mel.py b/vocos_mlx/mel.py new file mode 100644 index 0000000..0c50591 --- /dev/null +++ b/vocos_mlx/mel.py @@ -0,0 +1,168 @@ +from __future__ import annotations +from functools import lru_cache +import math +from typing import Optional + +import mlx.core as mx + + +@lru_cache(maxsize=None) +def mel_filters( + sample_rate: int, + n_fft: int, + n_mels: int, + f_min: float = 0, + f_max: Optional[float] = None, + norm: Optional[str] = None, + mel_scale: str = "htk", +) -> mx.array: + def hz_to_mel(freq, mel_scale="htk"): + if mel_scale == "htk": + return 2595.0 * math.log10(1.0 + freq / 700.0) + + # slaney scale + f_min, f_sp = 0.0, 200.0 / 3 + mels = (freq - f_min) / f_sp + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + if freq >= min_log_hz: + mels = min_log_mel + math.log(freq / min_log_hz) / logstep + return mels + + def mel_to_hz(mels, mel_scale="htk"): + if mel_scale == "htk": + return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + + # slaney scale + f_min, f_sp = 0.0, 200.0 / 3 + freqs = f_min + f_sp * mels + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + log_t = mels >= min_log_mel + freqs[log_t] = min_log_hz * mx.exp(logstep * (mels[log_t] - min_log_mel)) + return freqs + + f_max = f_max or sample_rate / 2 + + # generate frequency points + + n_freqs = n_fft // 2 + 1 + all_freqs = mx.linspace(0, sample_rate // 2, n_freqs) + + # convert frequencies to mel and back to hz + + m_min = hz_to_mel(f_min, mel_scale) + m_max = hz_to_mel(f_max, mel_scale) + m_pts = mx.linspace(m_min, m_max, n_mels + 2) + f_pts = mel_to_hz(m_pts, mel_scale) + + # compute slopes for filterbank + + f_diff = f_pts[1:] - f_pts[:-1] + slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1) + + # calculate overlapping triangular filters + + down_slopes = (-slopes[:, :-2]) / f_diff[:-1] + up_slopes = slopes[:, 2:] / f_diff[1:] + filterbank = mx.maximum( + mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes) + ) + + if norm == "slaney": + enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) + filterbank *= mx.expand_dims(enorm, 0) + + filterbank = filterbank.moveaxis(0, 1) + return filterbank + + +@lru_cache(maxsize=None) +def hanning(size): + return mx.array( + [0.5 * (1 - math.cos(2 * math.pi * n / (size - 1))) for n in range(size)] + ) + + +def stft(x, window, nperseg=256, noverlap=None, nfft=None, pad_mode="constant"): + if nfft is None: + nfft = nperseg + if noverlap is None: + noverlap = nfft // 4 + + def _pad(x, padding, pad_mode="constant"): + if pad_mode == "constant": + return mx.pad(x, [(padding, padding)]) + elif pad_mode == "reflect": + prefix = x[1 : padding + 1][::-1] + suffix = x[-(padding + 1) : -1][::-1] + return mx.concatenate([prefix, x, suffix]) + else: + raise ValueError(f"Invalid pad_mode {pad_mode}") + + padding = nperseg // 2 + x = _pad(x, padding, pad_mode) + + strides = [noverlap, 1] + t = (x.size - nperseg + noverlap) // noverlap + shape = [t, nfft] + x = mx.as_strided(x, shape=shape, strides=strides) + return mx.fft.rfft(x * window) + + +def istft(x, window, nperseg=256, noverlap=None, nfft=None): + if nfft is None: + nfft = nperseg + if noverlap is None: + noverlap = nfft // 4 + + t = (x.shape[0] - 1) * noverlap + nperseg + reconstructed = mx.zeros(t) + window_sum = mx.zeros(t) + + for i in range(x.shape[0]): + # inverse FFT of each frame + frame_time = mx.fft.irfft(x[i]) + + # get the position in the time-domain signal to add the frame + start = i * noverlap + end = start + nperseg + + # overlap-add the inverse transformed frame, scaled by the window + reconstructed[start:end] += frame_time * window + window_sum[start:end] += window + + # normalize by the sum of the window values + reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed) + + return reconstructed + + +def log_mel_spectrogram( + audio: mx.array, + sample_rate: int = 24_000, + n_mels: int = 100, + n_fft: int = 1024, + hop_length: int = 256, + padding: int = 0, +): + if not isinstance(audio, mx.array): + audio = mx.array(audio) + + if padding > 0: + audio = mx.pad(audio, (0, padding)) + + freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length) + magnitudes = freqs[:-1, :].abs() + filters = mel_filters( + sample_rate=sample_rate, + n_fft=n_fft, + n_mels=n_mels, + norm=None, + mel_scale="htk", + ) + mel_spec = magnitudes @ filters.T + log_spec = mx.maximum(mel_spec, 1e-5).log() + return mx.expand_dims(log_spec, axis=0) diff --git a/vocos_mlx/model.py b/vocos_mlx/model.py index 64f5186..6b64ad7 100644 --- a/vocos_mlx/model.py +++ b/vocos_mlx/model.py @@ -1,7 +1,4 @@ from __future__ import annotations -from functools import lru_cache -import math -import os from pathlib import Path from typing import Any, List, Optional from types import SimpleNamespace @@ -13,103 +10,7 @@ import yaml from vocos_mlx.encodec import EncodecModel - - -@lru_cache(maxsize=None) -def mel_filters(n_mels: int) -> mx.array: - """ - load the mel filterbank matrix for projecting STFT into a Mel spectrogram. - Saved using extract_filterbank.py - """ - assert n_mels in {100}, f"Unsupported n_mels: {n_mels}" - - filename = os.path.join("assets", "mel_filters.npz") - return mx.load(filename, format="npz")[f"mel_{n_mels}"] - - -@lru_cache(maxsize=None) -def hanning(size): - return mx.array( - [0.5 * (1 - math.cos(2 * math.pi * n / (size - 1))) for n in range(size)] - ) - - -def stft(x, window, nperseg=256, noverlap=None, nfft=None, pad_mode="constant"): - if nfft is None: - nfft = nperseg - if noverlap is None: - noverlap = nfft // 4 - - def _pad(x, padding, pad_mode="constant"): - if pad_mode == "constant": - return mx.pad(x, [(padding, padding)]) - elif pad_mode == "reflect": - prefix = x[1 : padding + 1][::-1] - suffix = x[-(padding + 1) : -1][::-1] - return mx.concatenate([prefix, x, suffix]) - else: - raise ValueError(f"Invalid pad_mode {pad_mode}") - - padding = nperseg // 2 - x = _pad(x, padding, pad_mode) - - strides = [noverlap, 1] - t = (x.size - nperseg + noverlap) // noverlap - shape = [t, nfft] - x = mx.as_strided(x, shape=shape, strides=strides) - return mx.fft.rfft(x * window) - - -def istft(x, window, nperseg=256, noverlap=None, nfft=None): - if nfft is None: - nfft = nperseg - if noverlap is None: - noverlap = nfft // 4 - - t = (x.shape[0] - 1) * noverlap + nperseg - reconstructed = mx.zeros(t) - window_sum = mx.zeros(t) - - for i in range(x.shape[0]): - # inverse FFT of each frame - frame_time = mx.fft.irfft(x[i]) - - # get the position in the time-domain signal to add the frame - start = i * noverlap - end = start + nperseg - - # overlap-add the inverse transformed frame, scaled by the window - reconstructed[start:end] += frame_time * window - window_sum[start:end] += window - - # normalize by the sum of the window values - reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed) - - return reconstructed - - -def log_mel_spectrogram( - audio: mx.array, - n_mels: int = 100, - n_fft: int = 1024, - hop_length: int = 256, - padding: int = 0, - filterbank: Optional[mx.array] = None, -): - if not isinstance(audio, mx.array): - audio = mx.array(audio) - - if padding > 0: - audio = mx.pad(audio, (0, padding)) - - freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length) - magnitudes = freqs[:-1, :].abs() - filters = filterbank if filterbank is not None else mel_filters(n_mels) - mel_spec = magnitudes @ filters.T - log_spec = mx.maximum(mel_spec, 1e-5).log() - return mx.expand_dims(log_spec, axis=0) - - +from vocos_mlx.mel import log_mel_spectrogram, istft, hanning class FeatureExtractor(nn.Module): """Base class for feature extractors.""" @@ -125,25 +26,24 @@ def __init__( hop_length=256, n_mels=100, padding="center", - filterbank: Optional[mx.array] = None, ): super().__init__() if padding not in ["center", "same"]: raise ValueError("Padding must be 'center' or 'same'.") self.padding = padding + self.sample_rate = sample_rate self.n_fft = n_fft self.hop_length = hop_length self.n_mels = n_mels - self.filterbank = filterbank def __call__(self, audio: mx.array, **kwargs): return log_mel_spectrogram( audio, + sample_rate=self.sample_rate, n_mels=self.n_mels, n_fft=self.n_fft, hop_length=self.hop_length, - padding=0, - filterbank=self.filterbank, + padding=0 ) @@ -352,7 +252,7 @@ def __init__( @classmethod def from_hparams( - cls, config_path: str, filterbank: Optional[mx.array] = None + cls, config_path: str ) -> Vocos: """ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. @@ -362,8 +262,6 @@ def from_hparams( if "MelSpectrogramFeatures" in config.feature_extractor["class_path"]: feature_extractor_init_args = config.feature_extractor["init_args"] - if filterbank is not None: - feature_extractor_init_args["filterbank"] = filterbank feature_extractor = MelSpectrogramFeatures(**feature_extractor_init_args) elif "EncodecFeatures" in config.feature_extractor["class_path"]: feature_extractor = EncodecFeatures(**config.feature_extractor["init_args"]) @@ -391,17 +289,8 @@ def from_pretrained(cls, path_or_repo: str) -> Vocos: with open(model_path, "rb") as f: weights = mx.load(f) - # load the filterbank for model initialization - - try: - filterbank = weights.pop( - "feature_extractor.mel_spec.mel_scale.fb" - ).moveaxis(0, 1) - except KeyError: - filterbank = None - config_path = path / "config.yaml" - model = cls.from_hparams(config_path, filterbank) + model = cls.from_hparams(config_path) # remove unused weights try: