From 37d09cbb5c08ad5c0830f2332e5a67f9e9e62439 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Wed, 30 Oct 2024 15:13:57 -0700 Subject: [PATCH] Compute mel filterbanks instead of extracting it from the model weights. --- assets/mel_filters.npz | Bin 5339 -> 0 bytes extract_filterbank.py | 8 -- vocos_mlx/mel.py | 168 +++++++++++++++++++++++++++++++++++++++++ vocos_mlx/model.py | 123 ++---------------------------- 4 files changed, 174 insertions(+), 125 deletions(-) delete mode 100644 assets/mel_filters.npz delete mode 100644 extract_filterbank.py create mode 100644 vocos_mlx/mel.py diff --git a/assets/mel_filters.npz b/assets/mel_filters.npz deleted file mode 100644 index b2f036aeb2de4aa51da1f52bb704771dfe97c836..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5339 zcmZ{IXIxWHuy#UM6r$2Z2oX?<6zR<r6hiVWNmR9k8uKwZr-(l+)OcM<&XGjV2e0FUD?Qzlj zttZUUIb6H=6Fm;@3+%i2;Go$;drkWF(yYu)uHt-)lan#5!Kco@oHU<&wjOK8NgJJz z@hT@$zVY$-OU`ao{JZa)Ii}rhFCl_-O8ED;ugXI0j(YAgUCw%Z385NP9pBRDmxa(~ zu1R5sT29e)PZfGxwDuIsfCwpZA!vyWgyLTvbyvs*m{Cpwdw zG|+C78Hnm!;htOV&dY9@>CoH zru9p{voE{eZyBmQy9c7A?chFFn(lW3Mo*l7b04)H`To4Kq}%!_DoR)018OlXqw#~z zV;O!yZQ+g6*vxrRzZgh8JDU*e7n3FpWxoN%Qeh_ZZBFR?fvjQB-y(tx6>s5wrf}ML zCXNA?!)ou)XvS1_0oE5_Qt0ugNt2-mXBM9~(e>$9)f)^2do9`E7VfUWh;8Olp)Mhn zmcMP<5YBa&g=1it(N)~d@V^N(Pood$Znj|bi$R+!dZ!5+Tl>dAE_($O@LNkS$xzw3 zkGqf8tEk%WAvFH>4D(>qd*RA2XDsf$yp3xSG|c~3*vs<@NNC^OVOzm^WS0g2h=EisS zY#26FXxaUR=Ff;^8rg&^-`V9IL*zM9a_*O1xF!c?3v5{Pf0YG1>!*@zJ`ybSUy|<3 zpPSEOarZpiE9l>~?$_x&b>IcFqrP_xBiz@9&glt5jtaGdrcIYo}JRd5&T#Lr7tAa46hEqW7Q9I!1J-LX{>8-RVfG7O&IR3 zT>%hH1oMbhx!{1!^MmDVxxRP$;E2E$d~qqy^WUkoyww5dno7xRBs#Bzl#IJ<3uFnp zh1X{8|JyQWxw8EQ8*AzB$rfMVA{1{a5lc}(7r1W;k2?h3=F}ib$;>t|0qym@Uz=>g zc*qdR4Y$|kGpPTLZD8Lc2QNyyP7HH!spK%W>|G;!U!@J)fH zUHU|9+k|j-@G4@>32A85NSae;DUE3}4ty_%f;f`{qE^1lE?KZ)js^qLswA_(e(?u6 zr@@ zSFjQ9SHtkbzO>Q!QAWQ4joRvJfMnp`M;lPP)m;# z!I)%aRLGEu z0D`W`vx!7|Hij&fT;cmZPRy_il)uw0dbd!TmBU_`z~%8!AdrY6+Yu_mkjeyK|4%h) z<4dWucK^aZ7roFlBSVH;?F&thZbu=}Um4zBsCAt2i{rwmdwCd61&A%bfcZFH?%S!Q zOi{P9ck#;DOSkGo&v%-dnaDkbpdhO#L6fhp_{QUzVEQ9IGj_!L3u2k7%40T@dPSb> zmGqc}xSWD&pY`WVP>HIik?}EybYf9!Wu7jz)_ZPO!-t9D?s?ItO-_BiY{<6ss)8( z39}hp3f%2oW8sEno$kOf8*^t&QfVCq zW4k%CS%v;Q<)cCIWlv#&nOQq#*p!}A(#{GSFvIxDH9^(I?{LICpXpQfjA$oj#PLCw zmMPZS%05I(di`31XwRv}aqeQ8Zi7|ym@r1qLip^Sn_^V2O+9bn^5d5@ zD2V^lG{K&%hu$%?ynZpQ1cv}GJ$eJXj|qhWLAC(4*_MJGzP%C;mia8mm%!(M(bED;8^AQk*EX;LxitDj#0--`D)p3d^KnbL<%hW>`|o)%~g~Mw8T-5UwH>U)-5#M#{cu9O>{Z|QfMe4yopS$n$H21n=}*O*GJDa9huF9L2Fz(p~? z)7e88mX1f95<9^rOa^-KxLr*$gyf@iX^J#s4SsEsuxsfKmj+i6O}ESnqRPcKpu9fG zSw&XtfboMI(i2;A^C;FSn5lSGcalea+5k?I3}ci&2ShS6#bL;np0 z;Ox{_dJ3%nUDju*Axaq++cJ|Oe@=9ClJ2~t9DnG}_J4RHC=iK68y~_&MVWI5aGLgP z%EMLu{K49p)E&F)3p*K)&U!=#7b@tyOG>WpbX*^JU z$W5a@VIiRFB%g*v^UuU3koVT>=%E$YaQd(L_MgYl&OcNVAl)D|n+vXaNGCa|%YmGm z*^sYiEc-i?Y2J!rmBhXJ4UV8wr!Pu92!zvwS^x1EaPQ|JL*|Kc>ihI&JIMQp@Fk{X zb-w;O5R>*Fv?E%b0GxSOP3=Y~H&JMf^f1JyC3h7 z47_>rRZiHLgme(>$3~42=wvsbxw(}4)&J;PBO`K&XbqQxiL0Q#bT2LM8yV%7)4#Jh z>bkf4RrlAW*YdYMI7&c{Omu7=pTl= z+Q3#Tkc1>%I-u|~cTwz85-0+ek1j_Y^iejTKLU6BcRNFfCTbEY4%mdYEa3DdS@G(Z zg?CWs)S44^(Jh)B0X4%?o#da>*u7g&}(lT zNTJY}%Q!8r7p!38wS`ME6^_=W70S`h)n>1hDnvoBiB=gO+>1df2MRlvoZBb)TC!xz zhD^s$lR!ERUZz>yjS#1VB6vscI;{1vffRFdUzOm{{Q{)U5^GeCsK{fGCK6&C3Z|$5 z5!`~3iWYh)-*TTFIi^G%Rv}U7g4&nqish98;HYPPRx%WLKO9jyxLVlJs{5E3;Y04~ zJ9mZv=WY*Q_$b+2q=0KeG(1`nGF2A>i2ZHQ`_a7Dti@u#m4Qjv;U+^2b_V%&JM!Ql z)D>H2u^Qv^#1%mejv<3(d}9SYgC#Tl;>>DvIz@sji82RL@?fT9l7R$ZQ~Tf0l%i!)86by*i*n@ldFKH* z;xH?@=x2mAn4L_%EiBvdlpwE6zc!}K;VK|#0^~F#o%IUfwT~5$vx*9^-nW&ASE@PD#QxN|5 z@{@{Ise>qxu+c(gnbl%u6qPREs5 zYD-0#r6V9J%}fM^{>jC?7?H^Zz&XrI>vuEjg(CvorQQ+cUV*~7Y9=luiWLH|E-_E% zM)huhXkO>;DMF3fY(VeU4rs8_g#$;vXZ}OshLDlc`$mnx8 zQU_c^RD5~ia`+ypeBUKVl{ z8fd0t!45R>us$fR#r9EbuU_`Lqf+{Wa%y(W`G~%qv`3TF?O9tb@lU4WPd+X=Z1M}wJ2WiGB3^&`J*qJZj0 z)_&s)8LS|~7vXiM?lS@;^doY{x6VZa>R8rpjV@2>J1`?O_!b;&yaa*Ox2n#u$I*rZ ztLvy2hwr)JIe*n1K2D1pmBw(}#+4osJSbFW z?^@ogAv%MiTsNo{Brz0^ItsyRFDtpNOYXk=S_96fimT z2C3}#XsmE@|2yaow%ctscS*wlsI|ec{k*iE4QP9>m`%T3o8K3P+^&pg6)O+k002TF zj-D*jgGN>2YMbt~=>TZTyj3>|kwjU5Y=XPgRMi0$s3JiIIwK-jK#Ed=@Mmw&Z%77A zgf7#K1+^QXb_xf_1f=02k;;eC0(#b^zb}hNg`<32g88yZxqdz*bj|QvrzXMTLE>1x z0BIJYWmhYJDGg&E@>6LKjD6?cIRy3OyHJgH+{;*OaK&} zNdf<&A2Wew)Ms)3;-=?1+CZcYp;W&fHP25`cl#f(kVskmzB1^#?W-CA(?|VMHn3X# zuzTA)0{|NmPT2E|1t);mrmIei@t^~E(-TLwlP*Jpa%O_RZ*yV+z~Qu!OH_J=2TY9p zZn9>JAYwKC7`#%OO&?E3D$|JKKlJ)4fytc(!?83J64;6jWH%W}#9OHjlO8733`xZ! zm7At#tq&Y7;^U(8x~!o-DX9GjJOQAfu5POjq;E;t@4X33gTs@ltCdAhsQ<60k^ldZ mx.array: + def hz_to_mel(freq, mel_scale="htk"): + if mel_scale == "htk": + return 2595.0 * math.log10(1.0 + freq / 700.0) + + # slaney scale + f_min, f_sp = 0.0, 200.0 / 3 + mels = (freq - f_min) / f_sp + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + if freq >= min_log_hz: + mels = min_log_mel + math.log(freq / min_log_hz) / logstep + return mels + + def mel_to_hz(mels, mel_scale="htk"): + if mel_scale == "htk": + return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + + # slaney scale + f_min, f_sp = 0.0, 200.0 / 3 + freqs = f_min + f_sp * mels + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + log_t = mels >= min_log_mel + freqs[log_t] = min_log_hz * mx.exp(logstep * (mels[log_t] - min_log_mel)) + return freqs + + f_max = f_max or sample_rate / 2 + + # generate frequency points + + n_freqs = n_fft // 2 + 1 + all_freqs = mx.linspace(0, sample_rate // 2, n_freqs) + + # convert frequencies to mel and back to hz + + m_min = hz_to_mel(f_min, mel_scale) + m_max = hz_to_mel(f_max, mel_scale) + m_pts = mx.linspace(m_min, m_max, n_mels + 2) + f_pts = mel_to_hz(m_pts, mel_scale) + + # compute slopes for filterbank + + f_diff = f_pts[1:] - f_pts[:-1] + slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1) + + # calculate overlapping triangular filters + + down_slopes = (-slopes[:, :-2]) / f_diff[:-1] + up_slopes = slopes[:, 2:] / f_diff[1:] + filterbank = mx.maximum( + mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes) + ) + + if norm == "slaney": + enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) + filterbank *= mx.expand_dims(enorm, 0) + + filterbank = filterbank.moveaxis(0, 1) + return filterbank + + +@lru_cache(maxsize=None) +def hanning(size): + return mx.array( + [0.5 * (1 - math.cos(2 * math.pi * n / (size - 1))) for n in range(size)] + ) + + +def stft(x, window, nperseg=256, noverlap=None, nfft=None, pad_mode="constant"): + if nfft is None: + nfft = nperseg + if noverlap is None: + noverlap = nfft // 4 + + def _pad(x, padding, pad_mode="constant"): + if pad_mode == "constant": + return mx.pad(x, [(padding, padding)]) + elif pad_mode == "reflect": + prefix = x[1 : padding + 1][::-1] + suffix = x[-(padding + 1) : -1][::-1] + return mx.concatenate([prefix, x, suffix]) + else: + raise ValueError(f"Invalid pad_mode {pad_mode}") + + padding = nperseg // 2 + x = _pad(x, padding, pad_mode) + + strides = [noverlap, 1] + t = (x.size - nperseg + noverlap) // noverlap + shape = [t, nfft] + x = mx.as_strided(x, shape=shape, strides=strides) + return mx.fft.rfft(x * window) + + +def istft(x, window, nperseg=256, noverlap=None, nfft=None): + if nfft is None: + nfft = nperseg + if noverlap is None: + noverlap = nfft // 4 + + t = (x.shape[0] - 1) * noverlap + nperseg + reconstructed = mx.zeros(t) + window_sum = mx.zeros(t) + + for i in range(x.shape[0]): + # inverse FFT of each frame + frame_time = mx.fft.irfft(x[i]) + + # get the position in the time-domain signal to add the frame + start = i * noverlap + end = start + nperseg + + # overlap-add the inverse transformed frame, scaled by the window + reconstructed[start:end] += frame_time * window + window_sum[start:end] += window + + # normalize by the sum of the window values + reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed) + + return reconstructed + + +def log_mel_spectrogram( + audio: mx.array, + sample_rate: int = 24_000, + n_mels: int = 100, + n_fft: int = 1024, + hop_length: int = 256, + padding: int = 0, +): + if not isinstance(audio, mx.array): + audio = mx.array(audio) + + if padding > 0: + audio = mx.pad(audio, (0, padding)) + + freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length) + magnitudes = freqs[:-1, :].abs() + filters = mel_filters( + sample_rate=sample_rate, + n_fft=n_fft, + n_mels=n_mels, + norm=None, + mel_scale="htk", + ) + mel_spec = magnitudes @ filters.T + log_spec = mx.maximum(mel_spec, 1e-5).log() + return mx.expand_dims(log_spec, axis=0) diff --git a/vocos_mlx/model.py b/vocos_mlx/model.py index 64f5186..6b64ad7 100644 --- a/vocos_mlx/model.py +++ b/vocos_mlx/model.py @@ -1,7 +1,4 @@ from __future__ import annotations -from functools import lru_cache -import math -import os from pathlib import Path from typing import Any, List, Optional from types import SimpleNamespace @@ -13,103 +10,7 @@ import yaml from vocos_mlx.encodec import EncodecModel - - -@lru_cache(maxsize=None) -def mel_filters(n_mels: int) -> mx.array: - """ - load the mel filterbank matrix for projecting STFT into a Mel spectrogram. - Saved using extract_filterbank.py - """ - assert n_mels in {100}, f"Unsupported n_mels: {n_mels}" - - filename = os.path.join("assets", "mel_filters.npz") - return mx.load(filename, format="npz")[f"mel_{n_mels}"] - - -@lru_cache(maxsize=None) -def hanning(size): - return mx.array( - [0.5 * (1 - math.cos(2 * math.pi * n / (size - 1))) for n in range(size)] - ) - - -def stft(x, window, nperseg=256, noverlap=None, nfft=None, pad_mode="constant"): - if nfft is None: - nfft = nperseg - if noverlap is None: - noverlap = nfft // 4 - - def _pad(x, padding, pad_mode="constant"): - if pad_mode == "constant": - return mx.pad(x, [(padding, padding)]) - elif pad_mode == "reflect": - prefix = x[1 : padding + 1][::-1] - suffix = x[-(padding + 1) : -1][::-1] - return mx.concatenate([prefix, x, suffix]) - else: - raise ValueError(f"Invalid pad_mode {pad_mode}") - - padding = nperseg // 2 - x = _pad(x, padding, pad_mode) - - strides = [noverlap, 1] - t = (x.size - nperseg + noverlap) // noverlap - shape = [t, nfft] - x = mx.as_strided(x, shape=shape, strides=strides) - return mx.fft.rfft(x * window) - - -def istft(x, window, nperseg=256, noverlap=None, nfft=None): - if nfft is None: - nfft = nperseg - if noverlap is None: - noverlap = nfft // 4 - - t = (x.shape[0] - 1) * noverlap + nperseg - reconstructed = mx.zeros(t) - window_sum = mx.zeros(t) - - for i in range(x.shape[0]): - # inverse FFT of each frame - frame_time = mx.fft.irfft(x[i]) - - # get the position in the time-domain signal to add the frame - start = i * noverlap - end = start + nperseg - - # overlap-add the inverse transformed frame, scaled by the window - reconstructed[start:end] += frame_time * window - window_sum[start:end] += window - - # normalize by the sum of the window values - reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed) - - return reconstructed - - -def log_mel_spectrogram( - audio: mx.array, - n_mels: int = 100, - n_fft: int = 1024, - hop_length: int = 256, - padding: int = 0, - filterbank: Optional[mx.array] = None, -): - if not isinstance(audio, mx.array): - audio = mx.array(audio) - - if padding > 0: - audio = mx.pad(audio, (0, padding)) - - freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length) - magnitudes = freqs[:-1, :].abs() - filters = filterbank if filterbank is not None else mel_filters(n_mels) - mel_spec = magnitudes @ filters.T - log_spec = mx.maximum(mel_spec, 1e-5).log() - return mx.expand_dims(log_spec, axis=0) - - +from vocos_mlx.mel import log_mel_spectrogram, istft, hanning class FeatureExtractor(nn.Module): """Base class for feature extractors.""" @@ -125,25 +26,24 @@ def __init__( hop_length=256, n_mels=100, padding="center", - filterbank: Optional[mx.array] = None, ): super().__init__() if padding not in ["center", "same"]: raise ValueError("Padding must be 'center' or 'same'.") self.padding = padding + self.sample_rate = sample_rate self.n_fft = n_fft self.hop_length = hop_length self.n_mels = n_mels - self.filterbank = filterbank def __call__(self, audio: mx.array, **kwargs): return log_mel_spectrogram( audio, + sample_rate=self.sample_rate, n_mels=self.n_mels, n_fft=self.n_fft, hop_length=self.hop_length, - padding=0, - filterbank=self.filterbank, + padding=0 ) @@ -352,7 +252,7 @@ def __init__( @classmethod def from_hparams( - cls, config_path: str, filterbank: Optional[mx.array] = None + cls, config_path: str ) -> Vocos: """ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. @@ -362,8 +262,6 @@ def from_hparams( if "MelSpectrogramFeatures" in config.feature_extractor["class_path"]: feature_extractor_init_args = config.feature_extractor["init_args"] - if filterbank is not None: - feature_extractor_init_args["filterbank"] = filterbank feature_extractor = MelSpectrogramFeatures(**feature_extractor_init_args) elif "EncodecFeatures" in config.feature_extractor["class_path"]: feature_extractor = EncodecFeatures(**config.feature_extractor["init_args"]) @@ -391,17 +289,8 @@ def from_pretrained(cls, path_or_repo: str) -> Vocos: with open(model_path, "rb") as f: weights = mx.load(f) - # load the filterbank for model initialization - - try: - filterbank = weights.pop( - "feature_extractor.mel_spec.mel_scale.fb" - ).moveaxis(0, 1) - except KeyError: - filterbank = None - config_path = path / "config.yaml" - model = cls.from_hparams(config_path, filterbank) + model = cls.from_hparams(config_path) # remove unused weights try: