
Inconsistent outputs when running onnx and pytorch (stft and istft) #23219

Open
etemesi254 opened this issue Dec 28, 2024 · 0 comments
Comments
Describe the issue

Hi, thanks for the great library :)

asteroid-filterbanks (https://github.com/asteroid-team/asteroid-filterbanks) provides an ONNX-exportable implementation of the STFT and ISTFT operations, which I am using in a speech separation model. The STFT and ISTFT are integrated into the model for easier end-to-end inference.

Exporting to ONNX emits the warnings shown below, and the exported model produces artifacts that add audible noise to the output, which is not ideal.

I am not sure whether this is an issue in asteroid-filterbanks or in the ONNX export, and would appreciate someone looking into it. Thanks.
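For reference, the window configuration used below (a periodic hann window with hop = n_fft / 4) satisfies the constant-overlap-add (COLA) condition, so the overlap-add envelope is flat away from the signal edges and the transform itself should not introduce noise there. A quick numpy check of this (using the parameters from the repro script, nothing asteroid-specific):

```python
import numpy as np

# COLA check for the repro parameters: a periodic hann window with
# hop = n_fft / 4. Interior samples overlap-add to the constant
# n_fft / (2 * hop) == 2.0, so artifacts in the round trip point at
# edge/length handling rather than the transform itself.
n_fft, hop = 4096, 1024
n = np.arange(n_fft)
window = 0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft)  # periodic hann

acc = np.zeros(4 * n_fft)
for start in range(0, len(acc) - n_fft + 1, hop):
    acc[start : start + n_fft] += window

interior = acc[n_fft:-n_fft]
print(interior.min(), interior.max())  # both ~2.0
```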

Error/warning output logged during ONNX export


  warnings.warn(
/miniconda3/envs/rizumu/lib/python3.11/site-packages/asteroid_filterbanks/enc_dec.py:294: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  length = min(length, wav.shape[-1])
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/_internal/jit_utils.py:308: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/utils.py:663: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
  _C._jit_pass_onnx_graph_shape_type_inference(
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/utils.py:1186: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
  _C._jit_pass_onnx_graph_shape_type_inference(
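The first TracerWarning points at `length = min(length, wav.shape[-1])` in enc_dec.py: a Python-level `min` is resolved once at trace time, so the chosen length is frozen into the graph and will not track other input lengths. A possible workaround (a hypothetical rewrite for illustration, not asteroid's actual code) is to keep the comparison in tensor ops so the tracer can record the data flow:

```python
import torch

# Hypothetical rewrite of the line flagged by the TracerWarning. A Python
# min() over wav.shape[-1] is evaluated eagerly while tracing; expressing
# the clamp with tensor ops keeps the length symbolic in the traced graph.
def trim_to_length(wav: torch.Tensor, length: int) -> torch.Tensor:
    safe_len = torch.minimum(
        torch.tensor(length), torch.tensor(wav.shape[-1])
    )
    # a 0-dim integer tensor is a valid slice bound
    return wav[..., :safe_len]

out = trim_to_length(torch.randn(2, 30000), 20000)
print(out.shape)  # torch.Size([2, 20000])
```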

To reproduce

Colaboratory link: https://colab.research.google.com/drive/1mNCwjGqMWLSAIZOIi1FJOJgfJqOmjxWn#scrollTo=H2c-2PWuNxxg

Installing dependencies

!pip install onnxruntime onnx asteroid-filterbanks

Code

from typing import Optional

import onnxruntime
import torch
from torch import nn, Tensor

from asteroid_filterbanks.enc_dec import Encoder, Decoder
from asteroid_filterbanks.transforms import to_torchaudio, from_torchaudio
from asteroid_filterbanks import torch_stft_fb


class AsteroidSTFT(nn.Module):
    def __init__(self, fb):
        super(AsteroidSTFT, self).__init__()
        self.enc = Encoder(fb)

    def forward(self, x):
        aux = self.enc(x)
        return to_torchaudio(aux)


class AsteroidISTFT(nn.Module):
    def __init__(self, fb):
        super(AsteroidISTFT, self).__init__()
        self.dec = Decoder(fb)

    def forward(self, x: Tensor, length: Optional[int] = None) -> Tensor:
        aux = from_torchaudio(x)
        x = self.dec(aux, length=length)
        return x


def make_filterbanks(n_fft=4096, n_hop=1024, center=True, sample_rate=44100.0):
    window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False)

    fb = torch_stft_fb.TorchSTFTFB.from_torch_args(
        n_fft=n_fft,
        hop_length=n_hop,
        win_length=n_fft,
        window=window,
        center=center,
        sample_rate=sample_rate,
    )
    encoder = AsteroidSTFT(fb)
    decoder = AsteroidISTFT(fb)

    return encoder, decoder



class TempTest(nn.Module):
    def __init__(self):
        super(TempTest, self).__init__()

        self.stft, self.istft = make_filterbanks()

    def forward(self, x: Tensor) -> Tensor:
        initial_size = x.shape[-1]
        was_unsqueezed = False

        if x.ndim == 2:
            # the stft expects (batch, audio, channel) while the model
            # takes (audio, channel), so add a fake batch dimension
            x = x.unsqueeze(0)
            was_unsqueezed = True
        prev_device = x.device
        x_cpu = x.to("cpu")
        self.stft = self.stft.to("cpu")
        x = self.stft(x_cpu)
        x = self.istft(x, initial_size)
        # move back to the previous device
        x = x.to(prev_device)

        if was_unsqueezed:
            # remove the fake batch dimension
            x = x.squeeze(dim=0)
        return x

if __name__ == '__main__':
    model = TempTest()
    audio = torch.randn((1, 20000))
    c = model(audio)
    torch.testing.assert_close(c, audio)
    # export to onnx (the keyword is `dynamo`, not `dynamo_export`)
    torch.onnx.export(model, audio, "./temp_test.onnx",
                      dynamo=True,
                      external_data=False,
                      report=True,
                      verify=True,
                      input_names=["input"],
                      output_names=["output"],
                      dynamic_axes={"input": {0: "channels", 1: "length"},
                                    "output": {0: "channels", 1: "length"}})
    sess = onnxruntime.InferenceSession("./temp_test.onnx")
    output = sess.run(["output"], {"input": audio.detach().numpy()})[0]
    torch.testing.assert_close(torch.from_numpy(output), audio)
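For comparison, a plain torch.stft / torch.istft round trip with the same parameters as make_filterbanks() (hann window, n_fft=4096, hop=1024, center=True) reconstructs the input to float32 tolerance in eager PyTorch, consistent with the assert_close on the eager model above, which suggests the noise only appears in the exported graph:

```python
import torch

# Eager-mode baseline with the same parameters as make_filterbanks():
# torch.stft followed by torch.istft reconstructs the signal to float32
# tolerance, so the discrepancy is specific to the ONNX export path.
n_fft, hop = 4096, 1024
window = torch.hann_window(n_fft)
audio = torch.randn(1, 20000)

spec = torch.stft(audio, n_fft, hop_length=hop, win_length=n_fft,
                  window=window, center=True, return_complex=True)
recon = torch.istft(spec, n_fft, hop_length=hop, win_length=n_fft,
                    window=window, center=True, length=audio.shape[-1])

print(torch.max(torch.abs(recon - audio)).item())  # tiny, on the order of 1e-6
```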

Urgency

No response

Platform

Mac

OS Version

15.0 (24A335)

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

onnx==1.17.0 onnxruntime==1.20.1

ONNX Runtime API

Python

Architecture

ARM64

Execution Provider

Default CPU

Execution Provider Library Version

No response
