Skip to content

Commit

Permalink
Remove external filterbank dependency.
Browse files: browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Sep 30, 2024
1 parent 1460718 commit 7e79b2c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 28 deletions.
Binary file modified assets/mel_filters.npz
Binary file not shown.
17 changes: 5 additions & 12 deletions extract_filterbank.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import librosa
import mlx.core as mx
import numpy as np

filterbank = librosa.filters.mel(
sr=24000,
n_fft=1024,
n_mels=100,
norm = None,
htk = True
)
from huggingface_hub import hf_hub_download

np.savez_compressed(
"assets/mel_filters.npz",
mel_100=filterbank
)
path = hf_hub_download("lucasnewman/vocos-mel-24khz", "model.safetensors")
filterbank = mx.load(path)["feature_extractor.mel_spec.mel_scale.fb"].moveaxis(0, 1)
np.savez_compressed("assets/mel_filters.npz", mel_100=filterbank)
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[build-system]
requires = [
"huggingface_hub",
"mlx>=0.18.0",
"mlx",
"numpy",
"pyyaml",
"setuptools",
Expand All @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "vocos-mlx"
version = "0.0.4"
version = "0.0.5"
authors = [{name = "Lucas Newman", email = "[email protected]"}]
license = {text = "MIT"}
description = "Vocos - MLX"
Expand Down
55 changes: 41 additions & 14 deletions vocos_mlx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
def mel_filters(n_mels: int) -> mx.array:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using extract_filterbank.py
Saved using extract_filterbank.py
"""
assert n_mels in {100}, f"Unsupported n_mels: {n_mels}"

Expand Down Expand Up @@ -93,6 +93,7 @@ def log_mel_spectrogram(
n_fft: int = 1024,
hop_length: int = 256,
padding: int = 0,
filterbank: Optional[mx.array] = None,
):
"""
Compute the log-Mel spectrogram of
Expand Down Expand Up @@ -121,7 +122,7 @@ def log_mel_spectrogram(

freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length)
magnitudes = freqs[:-1, :].abs()
filters = mel_filters(n_mels)
filters = filterbank if filterbank is not None else mel_filters(n_mels)
mel_spec = magnitudes @ filters.T
log_spec = mx.maximum(mel_spec, 1e-5).log()
return mx.expand_dims(log_spec, axis=0)
Expand All @@ -137,20 +138,30 @@ def __call__(self, audio: mx.array, **kwargs) -> mx.array:
class MelSpectrogramFeatures(FeatureExtractor):
def __init__(
self,
sample_rate=24000,
sample_rate=24_000,
n_fft=1024,
hop_length=256,
n_mels=100,
padding="center",
filterbank: Optional[mx.array] = None,
):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.n_mels = n_mels
self.filterbank = filterbank

def __call__(self, audio, **kwargs):
return log_mel_spectrogram(
audio, n_mels=100, n_fft=1024, hop_length=256, padding=0
audio,
n_mels=self.n_mels,
n_fft=self.n_fft,
hop_length=self.hop_length,
padding=0,
filterbank=self.filterbank,
)


Expand Down Expand Up @@ -433,17 +444,20 @@ def __init__(
self.head = head

@classmethod
def from_hparams(cls, config_path: str) -> Vocos:
def from_hparams(
cls, config_path: str, filterbank: Optional[mx.array] = None
) -> Vocos:
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = SimpleNamespace(**yaml.safe_load(f))

if "MelSpectrogramFeatures" in config.feature_extractor["class_path"]:
feature_extractor = MelSpectrogramFeatures(
**config.feature_extractor["init_args"]
)
feature_extractor_init_args = config.feature_extractor["init_args"]
if filterbank is not None:
feature_extractor_init_args["filterbank"] = filterbank
feature_extractor = MelSpectrogramFeatures(**feature_extractor_init_args)
elif "EncodecFeatures" in config.feature_extractor["class_path"]:
feature_extractor = EncodecFeatures(**config.feature_extractor["init_args"])
backbone = VocosBackbone(**config.backbone["init_args"])
Expand All @@ -466,22 +480,29 @@ def from_pretrained(cls, path_or_repo: str) -> Vocos:
)
)

config_path = path / "config.yaml"
model = cls.from_hparams(config_path)

model_path = path / "model.safetensors"
with open(model_path, "rb") as f:
weights = mx.load(f)

# load the filterbank for model initialization

try:
filterbank = weights.pop(
"feature_extractor.mel_spec.mel_scale.fb"
).moveaxis(0, 1)
except KeyError:
filterbank = None

config_path = path / "config.yaml"
model = cls.from_hparams(config_path, filterbank)

# remove unused weights
try:
del weights["feature_extractor.mel_spec.spectrogram.window"]
del weights["feature_extractor.mel_spec.mel_scale.fb"]
del weights["head.istft.window"]
except KeyError:
pass

del weights["head.istft.window"]

# transpose weights as needed
new_weights = {}
for k, v in weights.items():
Expand All @@ -504,6 +525,12 @@ def __call__(self, audio_input: mx.array, **kwargs: Any) -> mx.array:
audio_output = self.decode(features, **kwargs)
return audio_output

def get_encodec_codes(self, audio_input: mx.array, bandwidth_id: int) -> mx.array:
    """
    Delegate to the feature extractor's get_encodec_codes and return its result.

    Raises ValueError when the model's feature extractor is not an
    EncodecFeatures instance.
    """
    extractor = self.feature_extractor
    if isinstance(extractor, EncodecFeatures):
        # Only the Encodec-based extractor exposes codes; forward both arguments unchanged.
        return extractor.get_encodec_codes(audio_input, bandwidth_id)

    raise ValueError("This model does not support getting encodec codes.")

def decode(self, features_input: mx.array, **kwargs: Any) -> mx.array:
x = self.backbone(features_input, **kwargs)
audio_output = self.head(x)
Expand Down

0 comments on commit 7e79b2c

Please sign in to comment.