Skip to content

Commit

Permalink
[DLTP-45176] Add complex compatibility in static mode for stft api.
Browse files Browse the repository at this point in the history
  • Loading branch information
KPatr1ck committed Mar 16, 2022
1 parent 0b07d62 commit 07430af
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions python/paddle/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,10 @@ def istft(x,
assert len(window.shape) == 1 and len(window) == win_length, \
'expected a 1D window tensor of size equal to win_length({}), but got window with shape {}.'.format(win_length, window.shape)
else:
window = paddle.ones(shape=(win_length, ))
window_dtype = paddle.float32 if x.dtype in [
paddle.float32, paddle.complex64
] else paddle.float64
window = paddle.ones(shape=(win_length, ), dtype=window_dtype)

if win_length < n_fft:
pad_left = (n_fft - win_length) // 2
Expand Down Expand Up @@ -536,11 +539,10 @@ def istft(x,
x = x[:, :, :n_fft // 2 + 1]
out = fft_c2r(x=x, n=None, axis=-1, norm=norm, forward=False, name=None)

out = paddle.multiply(out, window).transpose(
perm=[0, 2, 1]) # (batch, n_fft, num_frames)
out = overlap_add(
x=paddle.multiply(out, window).transpose(
perm=[0, 2, 1]), # (batch, n_fft, num_frames)
hop_length=hop_length,
axis=-1) # (batch, seq_length)
x=out, hop_length=hop_length, axis=-1) # (batch, seq_length)

window_envelop = overlap_add(
x=paddle.tile(
Expand Down

0 comments on commit 07430af

Please sign in to comment.