Add support for Modified Discrete Cosine Transform (MDCT) #2696
Comments
Hi @Kinyugo Welcome to the torchaudio project, and thanks for the proposal. Here are the general steps to add it to torchaudio.
If you are not used to writing a test, let us know and we can take over.
Now, about the implementation:
Also, please refer to CONTRIBUTING.md for setting up the development environment.
@mthrok Thank you for the guidelines on how to contribute. About the implementation:
Let me know your thoughts and suggestions on how we can go about it.
I have also managed to port a 1-to-1 implementation of the inverse. Since the MDCT is supposed to be an invertible operation, we can check the reconstruction error to verify some correctness, though I am not sure about the mathematical correctness of the implementation.

```python
import math
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import functional as F


def mdct(x: Tensor,
         window: Union[Callable[..., Tensor], Tensor],
         window_length: int,
         hop_length: Optional[int] = None,
         center: bool = True,
         pad_mode: str = "constant") -> Tensor:
    # Initialize the window
    if callable(window):
        window = window(window_length)
    if hop_length is None:
        hop_length = window_length // 2
    # Flatten the input tensor
    shape = x.shape
    x = x.reshape((-1, shape[-1]))
    # Derive the number of frequencies and frames
    n_freqs = window_length // 2
    n_frames = int(math.ceil(shape[-1] / hop_length)) + 1
    # Center-pad the signal
    if center:
        x = F.pad(x, (hop_length, hop_length), mode=pad_mode)
    # Initialize the MDCT output
    x_mdct = torch.zeros((x.shape[0], n_freqs, n_frames), device=x.device)
    # Prepare the pre- and post-processing arrays
    preprocess_arr = torch.exp(
        -1j * torch.pi / window_length *
        torch.arange(0, window_length, device=x.device)).unsqueeze(0)
    postprocess_arr = torch.exp(
        -1j * torch.pi / window_length * (window_length / 2 + 1) *
        torch.arange(0.5, window_length / 2 + 0.5,
                     device=x.device)).unsqueeze(0)
    # Loop over time frames
    i = 0
    for j in range(n_frames):
        # Window the signal
        x_segment = x[:, i:i + window_length] * window
        i = i + hop_length
        # Compute the Fourier transform of the windowed signal
        x_segment = torch.fft.fft(x_segment * preprocess_arr)
        x_mdct[:, :, j] = torch.real(x_segment[:, :n_freqs] * postprocess_arr)
    x_mdct = x_mdct.reshape(shape[:-1] + x_mdct.shape[-2:])
    return x_mdct


def imdct(x: Tensor,
          window: Union[Callable[..., Tensor], Tensor],
          window_length: int,
          hop_length: Optional[int] = None,
          center: bool = True,
          pad_mode: str = "constant") -> Tensor:
    # Initialize the window
    if callable(window):
        window = window(window_length)
    if hop_length is None:
        hop_length = window_length // 2
    # Flatten the input tensor
    shape = x.shape
    n_freqs, n_frames = x.shape[-2:]
    x = x.reshape((-1, n_freqs, n_frames))
    # Derive the number of samples
    n_samples = hop_length * (n_frames + 1)
    # Initialize the output signal
    x_imdct = torch.zeros((x.shape[0], n_samples), device=x.device)
    # Prepare the pre- and post-processing arrays
    preprocess_arr = torch.exp(
        -1j * torch.pi / (2 * n_freqs) * (n_freqs + 1) *
        torch.arange(0, n_freqs, device=x.device)).unsqueeze(dim=-1)
    postprocess_arr = (torch.exp(
        -1j * torch.pi / (2 * n_freqs) *
        torch.arange(0.5 + n_freqs / 2, 2 * n_freqs + n_freqs / 2 + 0.5,
                     device=x.device)) / n_freqs).unsqueeze(dim=-1)
    # Note: torch.fft.fft takes `dim`, not numpy's `axis`
    x = torch.fft.fft(x * preprocess_arr, n=2 * n_freqs, dim=1)
    # Apply the window function to the frames after post-processing
    x = 2 * (torch.real(x * postprocess_arr) * window.unsqueeze(dim=-1))
    # Loop over the time frames
    i = 0
    for j in range(n_frames):
        # Recover the signal via the time-domain aliasing cancellation principle
        x_imdct[:, i:i + window_length] += x[:, :, j]
        i = i + hop_length
    # Remove padding
    if center:
        x_imdct = x_imdct[:, hop_length:-hop_length]
    x_imdct = x_imdct.reshape((*shape[:-2], -1))
    return x_imdct


x = torch.randn(1, 2, 131_072)
window_length = 1024
# Vorbis window
w = torch.sin(torch.pi / 2 * torch.pow(
    torch.sin(torch.pi / window_length *
              torch.arange(0.5, window_length + 0.5)), 2))
y = mdct(x, w, window_length)
z = imdct(y, w, window_length)
loss = torch.mean(torch.abs(x - z))
print(loss)  # around 1e-5
```

I am yet to figure out how to optimize the loops.
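Regarding the loops: one way to vectorize the per-frame windowing is to build all frames at once with `Tensor.unfold` and then run a single batched FFT over them. A minimal sketch of the framing step, assuming the signal has already been center-padded as above (the `frame_signal` helper is an illustrative assumption, not torchaudio API):

```python
import torch


def frame_signal(x: torch.Tensor, window_length: int, hop_length: int) -> torch.Tensor:
    # Split a (batch, time) signal into overlapping frames of shape
    # (batch, n_frames, window_length) without a Python-level loop.
    return x.unfold(dimension=-1, size=window_length, step=hop_length)


x = torch.arange(16.0).reshape(1, 16)
frames = frame_signal(x, window_length=8, hop_length=4)
print(frames.shape)  # torch.Size([1, 3, 8])
# All frames could then be windowed and transformed in one call, e.g.:
# spec = torch.fft.fft(frames * window * preprocess_arr, dim=-1)
```

The overlap-add loop in `imdct` could potentially be vectorized in a similar spirit, e.g. with `torch.nn.functional.fold`.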
@mthrok I managed to get a working vectorized implementation of the MDCT algorithm. However, I am having trouble setting up the development environment on my local machine to open a PR. How do I go about it?
Hi @Kinyugo, you can refer to CONTRIBUTING.md to get started with the dev environment setup and process. Feel free to let us know if you encounter any errors while doing so!
Hello @carolineechen. I followed the steps in the contributing guide but ran into some problems installing the package, i.e.: subprocess.CalledProcessError: Command '['cmake', '--build', '.', '--target', 'install']' returned non-zero exit status 1.
Hi @Kinyugo, can you post the full error message, as well as your environment, so we can get a better idea of what the error is? To get the environment versions, you can use:
Full Error Message
Environment
Could you try downgrading mkl to version 2021.2.0, and then running:
The issue still persists.
Hi @Kinyugo, sorry for the late reply. It looks like this might be related to #2784, and there's a potential solution offered there. We also have a team member looking into this; hopefully it can be resolved soon, and we'll update this issue when it is. As you already have a working version offline, it might also be possible for you to create a draft PR anyway in the meantime (though it would be harder to verify that tests compile/pass locally), and we could provide some preliminary feedback.
Hello @carolineechen, I will look into creating a draft PR, as I would also appreciate some feedback on some of the design choices I made. Godspeed on resolving the issue. Thanks for your support.
Hi @mthrok & @carolineechen, I apologize for taking so long to get back to this. I have made an implementation of the MDCT algorithm in PyTorch here. Kindly let me know your thoughts. I will open a draft PR soon. Kindly help me out with the testing, as I cannot set up the development environment on my machine. I would appreciate any help or feedback. Thanks for your support.
cc @nateanl
Not to hijack the thread, but @Kinyugo, what are your thoughts on PQMF?
🚀 The feature
The Modified Discrete Cosine Transform (MDCT) is a perfectly invertible transform that can be used for feature extraction. It can be used as an alternative to MelSpectrograms especially where an invertible transform is desired such as audio synthesis in a more compressed space.
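For background on the invertibility claim: the MDCT gives perfect reconstruction when the analysis/synthesis window satisfies the Princen-Bradley condition, w[n]² + w[n + N/2]² = 1. A minimal sketch checking that the Vorbis window (as used in the snippet earlier in this thread) satisfies it; the window length N = 1024 is an arbitrary illustrative choice:

```python
import torch

N = 1024  # window length (illustrative)
n = torch.arange(0.5, N + 0.5)
# Vorbis window: sin((pi/2) * sin^2(pi * n / N))
w = torch.sin(torch.pi / 2 * torch.sin(torch.pi / N * n) ** 2)
# Princen-Bradley condition: w[n]^2 + w[n + N/2]^2 == 1 for all n
pb = w[: N // 2] ** 2 + w[N // 2:] ** 2
print(torch.allclose(pb, torch.ones(N // 2), atol=1e-5))  # True
```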
Motivation, pitch
I am working on an audio synthesis project, and an invertible transformation is desirable in this case, as opposed to something like a
MelSpectrogram
. The MDCT is a viable alternative. However, there is no implementation of it in torchaudio.
Alternatives
I have tried working with the
MelSpectrogram
transform. However, phase reconstruction is cumbersome and requires implementing complex neural vocoders, which is undesirable in my case.
Additional context
There is a NumPy implementation, Zaf, as well as a PyPI package, mdct. I would like to assist in porting these implementations to torchaudio. However, I require some guidance on how to go about it. Any help would be appreciated.
Currently I have a naive, 1-to-1 port of the Zaf implementation in PyTorch. I think there is room for a lot of optimization.