From 891ca4fd39298227b23dc5af7a385ca2fe6918f4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 27 Oct 2024 13:33:25 +0800 Subject: [PATCH] make buffers non-persistent --- .../prototype/transforms/_transforms.py | 6 ++--- src/torchaudio/transforms/_transforms.py | 25 ++++++++++--------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/torchaudio/prototype/transforms/_transforms.py b/src/torchaudio/prototype/transforms/_transforms.py index 9d89cc5339..68be2c00d5 100644 --- a/src/torchaudio/prototype/transforms/_transforms.py +++ b/src/torchaudio/prototype/transforms/_transforms.py @@ -55,7 +55,7 @@ def __init__( raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max)) fb = barkscale_fbanks(n_stft, self.f_min, self.f_max, self.n_barks, self.sample_rate, self.bark_scale) - self.register_buffer("fb", fb) + self.register_buffer("fb", fb, persistent=False) def forward(self, specgram: torch.Tensor) -> torch.Tensor: r""" @@ -138,7 +138,7 @@ def __init__( raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max)) fb = barkscale_fbanks(n_stft, self.f_min, self.f_max, self.n_barks, self.sample_rate, bark_scale) - self.register_buffer("fb", fb) + self.register_buffer("fb", fb, persistent=False) def forward(self, barkspec: torch.Tensor) -> torch.Tensor: r""" @@ -343,7 +343,7 @@ def __init__( fb = chroma_filterbank( sample_rate, n_freqs, n_chroma, tuning=tuning, ctroct=ctroct, octwidth=octwidth, norm=norm, base_c=base_c ) - self.register_buffer("fb", fb) + self.register_buffer("fb", fb, persistent=False) def forward(self, x: torch.Tensor) -> torch.Tensor: r""" diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 802cbd3d77..aafadcd1ce 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -83,7 +83,7 @@ def __init__( self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.register_buffer("window", window) + self.register_buffer("window", window, persistent=False) self.pad = pad self.power = power self.normalized = normalized @@ -177,7 +177,7 @@ def __init__( self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.register_buffer("window", window) + self.register_buffer("window", window, persistent=False) self.pad = pad self.normalized = normalized self.center = center @@ -266,7 +266,7 @@ def __init__( self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.register_buffer("window", window) + self.register_buffer("window", window, persistent=False) self.length = length self.power = power self.momentum = momentum @@ -397,7 +397,7 @@ def __init__( raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max)) fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) - self.register_buffer("fb", fb) + self.register_buffer("fb", fb, persistent=False) def forward(self, specgram: Tensor) -> Tensor: r""" @@ -476,7 +476,7 @@ def __init__( raise ValueError(f'driver must be one of ["gels", "gelsy", "gelsd", "gelss"]. Found {driver}.') fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale) - self.register_buffer("fb", fb) + self.register_buffer("fb", fb, persistent=False) def forward(self, melspec: Tensor) -> Tensor: r""" @@ -685,7 +685,7 @@ def __init__( if self.n_mfcc > self.MelSpectrogram.n_mels: raise ValueError("Cannot select more MFCC coefficients than # mel bins") dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm) - self.register_buffer("dct_mat", dct_mat) + self.register_buffer("dct_mat", dct_mat, persistent=False) self.log_mels = log_mels def forward(self, waveform: Tensor) -> Tensor: @@ -788,10 +788,10 @@ def __init__( n_filter=self.n_filter, sample_rate=self.sample_rate, ) - self.register_buffer("filter_mat", filter_mat) + self.register_buffer("filter_mat", filter_mat, persistent=False) dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm) - self.register_buffer("dct_mat", dct_mat) + self.register_buffer("dct_mat", dct_mat, persistent=False) self.log_lf = log_lf def forward(self, waveform: Tensor) -> Tensor: @@ -964,7 +964,7 @@ def __init__( beta, dtype=dtype, ) - self.register_buffer("kernel", kernel) + self.register_buffer("kernel", kernel, persistent=False) def forward(self, waveform: Tensor) -> Tensor: r""" @@ -1051,7 +1051,8 @@ def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201, fixed_ra n_fft = (n_freq - 1) * 2 hop_length = hop_length if hop_length is not None else n_fft // 2 - self.register_buffer("phase_advance", torch.linspace(0, math.pi * hop_length, n_freq)[..., None]) + phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] + self.register_buffer("phase_advance", phase_advance, persistent=False) def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor: r""" @@ -1652,7 +1653,7 @@ def __init__( self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.register_buffer("window", window) + self.register_buffer("window", window, persistent=False) self.pad = pad def forward(self, waveform: Tensor) -> Tensor: @@ -1717,7 +1718,7 @@ def __init__( self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 4 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.register_buffer("window", window) + self.register_buffer("window", window, persistent=False) rate = 2.0 ** (-float(n_steps) / bins_per_octave) self.orig_freq = int(sample_rate / rate) self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))