This repository has been archived by the owner on Sep 22, 2024. It is now read-only.
forked from pablodz/DNP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
124 lines (102 loc) · 4.25 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import soundfile as sf
import torch
import os
from scipy.io.wavfile import write
import numpy as np
import librosa
from scipy.special import expi
import librosa.display
import matplotlib.pyplot as plt
# Scale factor (2**15) that write_norm_music multiplies float audio by before
# casting it to 16-bit integer samples.
MAX_WAV_VALUE = 32768.0
class Accumulator:
    """Accumulate per-bin STFT fluctuations across iterations and derive masks.

    The instance keeps the STFT of the noisy input, a running sum of the
    relative change between successive spectrograms (``stft_diff_sum``), and
    the masks computed from that sum (``atten_map``, ``lsa_mask``).
    """

    def __init__(self, noisy_audio, low_cut, high_cut, nfft, center, residual, sr, bandpass):
        """Compute the noisy STFT and reset the accumulator.

        Args:
            noisy_audio: Tensor handed to ``torch_stft``.
            low_cut: Percentile (0-100) used as the lower clipping bound of
                the accumulated sum.
            high_cut: Percentile (0-100) used as the upper clipping bound of
                the accumulated sum.
            nfft: FFT size forwarded to ``torch_stft``.
            center: ``center`` flag forwarded to ``torch_stft``.
            residual: Floor applied to the attenuation map and to ``eta`` in
                ``mmse_lsa``.
            sr: Sample rate; stored only, not used by the methods shown here.
            bandpass: Number of leading rows zeroed by ``filter_stft``; about
                a third as many trailing rows are zeroed as well.
        """
        self.low_cut = low_cut
        self.high_cut = high_cut
        self.center = center
        self.nfft = nfft
        self.stft_noisy = torch_stft(noisy_audio, nfft=nfft, center=center)
        self.residual = residual
        self.sr = sr
        self.bandpass = bandpass
        self.stft_noisy_filt = self.filter_stft()
        self.stft_diff_sum = np.zeros(self.stft_noisy.shape)
        # No spectrogram has been seen yet; sum_difference treats this as warm-up.
        self.stft_prev = None
        self.stft_minus = None

    def filter_stft(self):
        """Return a copy of the noisy STFT with the band edges zeroed.

        The first ``bandpass`` rows (low frequencies) and the last
        ``ceil(bandpass / 3)`` rows (high frequencies) are multiplied by 0.
        ``self.stft_noisy`` itself is left untouched.
        """
        stft_full = self.stft_noisy + 0  # arithmetic forces a fresh copy
        stft_full[:self.bandpass, :] = 0 * stft_full[:self.bandpass, :]  # reduce low frequencies
        high_start = -self.bandpass // 3
        # With bandpass == 0 the start index is 0 and ``[0:]`` would select the
        # whole array, wiping the full spectrum; skip the high band in that case.
        if high_start != 0:
            stft_full[high_start:, :] = 0 * stft_full[high_start:, :]  # reduce high frequencies
        return stft_full

    def sum_difference(self, stft, iter_num, warmup_iters=48):
        """Add the relative change of ``stft`` versus the previous call.

        Args:
            stft: Current spectrogram; it is added in place to the real-valued
                accumulator, so it is presumably a magnitude array with the
                accumulator's shape -- TODO confirm against the caller.
            iter_num: Index of the current iteration.
            warmup_iters: Iterations below this index only record ``stft`` as
                the reference and do not accumulate. Defaults to 48, the
                value that used to be hard-coded.
        """
        # Without a previous spectrogram there is nothing to diff against, so a
        # first call past the warm-up is handled like a warm-up call.
        if iter_num < warmup_iters or getattr(self, "stft_prev", None) is None:
            self.stft_prev, self.stft_minus = stft, stft
            return

        self.stft_minus = np.abs(stft - self.stft_prev) / (stft + np.finfo(float).eps)
        self.stft_diff_sum += self.stft_minus

        # Clip the running sum to its [low_cut, high_cut] percentile range.
        # The upper percentile is taken after the lower clip, as before.
        lower = np.percentile(self.stft_diff_sum, self.low_cut)
        self.stft_diff_sum[self.stft_diff_sum < lower] = lower
        upper = np.percentile(self.stft_diff_sum, self.high_cut)
        self.stft_diff_sum[self.stft_diff_sum > upper] = upper

        self.stft_prev = stft

    def create_atten_map(self):
        """Build ``self.atten_map`` from the accumulated sum.

        The sum is min-max normalised and inverted, so the largest accumulated
        value maps to 0 and the smallest to 1; values below ``residual`` are
        then raised to ``residual``.
        """
        max_mask = self.stft_diff_sum.max()
        min_mask = self.stft_diff_sum.min()
        span = max_mask - min_mask
        if span == 0:
            # Constant accumulator: every bin equals the maximum, which the
            # formula maps to 0. Avoids the 0/0 -> NaN of a plain division.
            atten_map = np.zeros(self.stft_diff_sum.shape)
        else:
            atten_map = (max_mask - self.stft_diff_sum) / span
        atten_map[atten_map < self.residual] = self.residual
        self.atten_map = atten_map

    def mmse_lsa(self):
        """Compute ``self.lsa_mask`` from ``self.atten_map``.

        With ``gamma = 1 / max((1 - atten_map)**2, 1e-10)`` and
        ``eta = max(gamma - 1, residual)``, the gain is ``eta / (1 + eta)``
        for ``v > 5`` and ``eta / (1 + eta) * exp(0.5 * E1(v))`` for
        ``0 < v <= 5``, where ``v = gamma * eta / (1 + eta)`` and
        ``E1(v) = -expi(-v)``. Gains are capped at 1; bins with ``v <= 0``
        keep a gain of 1.
        """
        gamma = (1 - self.atten_map) ** 2
        gamma[gamma < 10**-10] = 10**-10  # keep the reciprocal finite
        gamma = 1 / gamma

        # Everything below is elementwise, so the whole map is processed at once.
        eta = gamma - 1
        eta[eta < self.residual] = self.residual
        v = gamma * eta / (1 + eta)

        gain = np.ones(gamma.shape)
        large_v = v > 5
        gain[large_v] = eta[large_v] / (1 + eta[large_v])
        small_v = np.logical_and(v <= 5, v > 0)
        gain[small_v] = (
            eta[small_v] / (1 + eta[small_v]) * np.exp(0.5 * -expi(-v[small_v]))
        )
        gain[gain > 1] = 1
        self.lsa_mask = gain

    def show_lsa(self):
        """Plot ``self.lsa_mask``."""
        plot_stft(self.lsa_mask, 'LSA Mask')

    def show_wiener(self):
        """Plot ``self.lsa_mask`` under the title 'Wiener Mask'."""
        # NOTE(review): this shows the LSA mask, not a separate Wiener mask;
        # no Wiener mask attribute exists in this class -- confirm intent.
        plot_stft(self.lsa_mask, 'Wiener Mask')

    def show_diff_accum(self):
        """Plot the accumulated sum ``self.stft_diff_sum``."""
        plot_stft(self.stft_diff_sum, 'Accumulator')

    def show_diff_stft(self):
        """Plot the most recent value stored in ``self.stft_minus``."""
        plot_stft(self.stft_minus, 'Abs Difference')

    def show_noisy(self):
        """Plot the magnitude of the noisy STFT."""
        plot_stft(np.abs(self.stft_noisy), 'Power Spectrogram of Noisy Sample')
def plot_stft(D, title=''):
    """Show *D* as a dB-scaled spectrogram with a logarithmic frequency axis.

    Args:
        D: Amplitude array to display; converted to dB relative to its maximum.
        title: Text placed above the plot.
    """
    decibels = librosa.amplitude_to_db(D, ref=np.max)
    librosa.display.specshow(decibels, y_axis='log')
    plt.title(title)
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    plt.show()
def load_wav_to_torch(full_path):
    """Read an audio file into a float tensor.

    The file is read as 16-bit integers; multi-channel input is averaged
    across channels into mono. Sample values are not rescaled.

    Args:
        full_path: Path of the audio file to read.

    Returns:
        A ``(waveform, sampling_rate)`` tuple, where ``waveform`` is a
        ``torch.FloatTensor``.
    """
    samples, sample_rate = sf.read(full_path, dtype='int16')
    has_channels = samples.ndim > 1
    if has_channels:
        print('Mixing the input wav to a mono file')
        samples = samples.mean(axis=1)
    waveform = torch.FloatTensor(samples.astype(np.float32))
    return waveform, sample_rate
def write_norm_music(input, filename, sr):
    """Write float audio to *filename* as 16-bit PCM.

    Args:
        input: Array of samples, scaled by ``MAX_WAV_VALUE`` before the
            integer cast.
        filename: Destination path passed to ``scipy.io.wavfile.write``.
        sr: Sample rate written to the file.
    """
    audio = MAX_WAV_VALUE * input
    # Saturate instead of letting the float -> int16 cast overflow: a sample
    # of exactly 1.0 scales to 32768, one past the largest int16 value.
    limits = np.iinfo(np.int16)
    wavdata = np.clip(audio, limits.min, limits.max).astype('int16')
    write(filename, sr, wavdata)
def torch_stft(audio, nfft=2048, center=True):
    """Return the librosa STFT of a torch tensor.

    Args:
        audio: Tensor of samples; a detached CPU copy is converted to NumPy.
        nfft: FFT size passed to ``librosa.stft`` as ``n_fft``.
        center: ``center`` flag passed to ``librosa.stft``.

    Returns:
        The array returned by ``librosa.stft``.
    """
    # Work on a detached copy so the caller's tensor is never touched.
    samples = audio.clone().detach().cpu().numpy()
    return librosa.stft(samples, n_fft=nfft, center=center)
def write_music_stft(stft, filename, sr, center=True):
    """Invert *stft* to a waveform and write it with ``write_norm_music``.

    Args:
        stft: STFT array passed to ``librosa.istft``.
        filename: Destination path of the audio file.
        sr: Sample rate written to the file.
        center: ``center`` flag passed to ``librosa.istft``.
    """
    waveform = librosa.istft(stft, center=center)
    write_norm_music(input=waveform, filename=filename, sr=sr)
def makedirs(outputs_dir):
    """Create *outputs_dir* (including parents) if it does not exist yet.

    Prints a message when the directory is about to be created; does nothing
    when the path already exists.

    Args:
        outputs_dir: Directory path to create.
    """
    if not os.path.exists(outputs_dir):
        print("Creating directory: {}".format(outputs_dir))
        # exist_ok tolerates another process creating the directory between
        # the existence check above and this call.
        os.makedirs(outputs_dir, exist_ok=True)