Commit f658b6f

Score-based generative enhancement model (#8567)
* Score-based generative enhancement model in NeMo
* Addressed comments, added unit test

Signed-off-by: Ante Jukić <[email protected]>
1 parent 3d87ed7 commit f658b6f

21 files changed: +2985 -349 lines

examples/audio_tasks/audio_to_audio_eval.py (+17 -1)
@@ -61,6 +61,7 @@
 import json
 import os
 import tempfile
+from collections import defaultdict
 from dataclasses import dataclass, field, is_dataclass
 from typing import List, Optional

@@ -101,6 +102,9 @@ class AudioEvaluationConfig(process_audio.ProcessConfig):
     # Metrics to calculate
     metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi'])

+    # Return metric values for each example
+    return_values_per_example: bool = False
+

 def get_evaluation_dataloader(config):
     """Prepare a dataloader for evaluation.

@@ -174,6 +178,9 @@ def main(cfg: AudioEvaluationConfig):
     # Setup metrics
     metrics = get_metrics(cfg)

+    if cfg.return_values_per_example and cfg.batch_size > 1:
+        raise ValueError('return_values_per_example is only supported for batch_size=1.')
+
     # Processing
     if not cfg.only_score_manifest:
         # Process audio using the configured model and save in the output directory

@@ -236,6 +243,10 @@ def main(cfg: AudioEvaluationConfig):

             num_files += 1

+            if cfg.max_utts is not None and num_files >= cfg.max_utts:
+                logging.info('Reached max_utts: %s', cfg.max_utts)
+                break
+
         # Prepare dataloader
         config = {
             'manifest_filepath': temporary_manifest_filepath,

@@ -249,6 +260,8 @@
         }
         temporary_dataloader = get_evaluation_dataloader(config)

+        metrics_value_per_example = defaultdict(list)
+
         # Calculate metrics
         for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'):
             processed_signal, processed_length, target_signal, target_length = eval_batch

@@ -257,7 +270,9 @@
                 raise RuntimeError(f'Length mismatch.')

             for name, metric in metrics.items():
-                metric.update(preds=processed_signal, target=target_signal, input_length=target_length)
+                value = metric(preds=processed_signal, target=target_signal, input_length=target_length)
+                if cfg.return_values_per_example:
+                    metrics_value_per_example[name].append(value.item())

         # Convert to a dictionary with name: value
         metrics_value = {name: metric.compute().item() for name, metric in metrics.items()}

@@ -277,6 +292,7 @@
     # Inject the metric name and score into the config, and return the entire config
     with open_dict(cfg):
         cfg.metrics_value = metrics_value
+        cfg.metrics_value_per_example = dict(metrics_value_per_example)

     return cfg

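The per-example change above switches from metric.update(...) to calling the metric directly, which returns the value for the current batch while still accumulating state for the final compute(). A minimal sketch of that pattern with plain torchmetrics (NeMo's metric wrapper additionally takes input_length, omitted here; the random signals are placeholders, not NeMo dataloader output):

# Minimal sketch of the per-example metric pattern, using plain torchmetrics.
from collections import defaultdict

import torch
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio

metrics = {'sisdr': ScaleInvariantSignalDistortionRatio()}
metrics_value_per_example = defaultdict(list)

for _ in range(3):  # stands in for the evaluation dataloader with batch_size=1
    preds, target = torch.randn(1, 16000), torch.randn(1, 16000)
    for name, metric in metrics.items():
        # metric(...) returns the value for this batch and, unlike update(),
        # also accumulates state for the final compute()
        value = metric(preds=preds, target=target)
        metrics_value_per_example[name].append(value.item())

# aggregate value over all examples, as in metrics_value above
print({name: metric.compute().item() for name, metric in metrics.items()})

With batch_size=1 each returned value is exactly one example's score, which is why the script rejects larger batches when return_values_per_example is set.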
examples/audio_tasks/conf/beamforming.yaml (+0 -1)
@@ -44,7 +44,6 @@ model:
     _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
     fft_length: 512 # Length of the window and FFT for calculating spectrogram
     hop_length: 256 # Hop length for calculating spectrogram
-    power: null

   decoder:
     _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio

examples/audio_tasks/conf/masking.yaml (+0 -3)
@@ -1,5 +1,3 @@
-# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer.
-#
 name: "masking"

 model:

@@ -44,7 +42,6 @@ model:
     _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
     fft_length: 512 # Length of the window and FFT for calculating spectrogram
     hop_length: 256 # Hop length for calculating spectrogram
-    power: null

   decoder:
     _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
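These configs instantiate the analysis/synthesis pair from the _target_ class paths via Hydra. A rough sketch of the equivalent direct construction follows; the import path and keyword names mirror the YAML, while the forward-call signature is an assumption rather than confirmed NeMo API:

# Rough sketch of constructing the analysis/synthesis pair directly rather
# than via Hydra's _target_ instantiation. Keyword arguments mirror the YAML;
# the forward-call signature below is an assumption, not confirmed API.
import torch
from nemo.collections.asr.modules.audio_preprocessing import (
    AudioToSpectrogram,
    SpectrogramToAudio,
)

encoder = AudioToSpectrogram(fft_length=512, hop_length=256)
decoder = SpectrogramToAudio(fft_length=512, hop_length=256)

audio = torch.randn(1, 1, 16000)  # (batch, channel, time), illustrative
length = torch.tensor([16000])
spec, spec_len = encoder(input=audio, input_length=length)  # assumed signature
audio_hat, _ = decoder(input=spec, input_length=spec_len)   # assumed signature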
New file (+130)
@@ -0,0 +1,130 @@
+name: "predictive_model"
+
+model:
+  type: predictive
+  sample_rate: 16000
+  skip_nan_grad: false
+  num_outputs: 1
+  normalize_input: true # normalize the input signal to 0dBFS
+
+  train_ds:
+    manifest_filepath: ???
+    input_key: noisy_filepath
+    target_key: clean_filepath
+    audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 256
+    random_offset: true
+    normalization_signal: input_signal
+    batch_size: 8 # batch size may be increased based on the available memory
+    shuffle: true
+    num_workers: 8
+    pin_memory: true
+
+  validation_ds:
+    manifest_filepath: ???
+    input_key: noisy_filepath
+    target_key: clean_filepath
+    batch_size: 8
+    shuffle: false
+    num_workers: 4
+    pin_memory: true
+
+  encoder:
+    _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
+    fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
+    hop_length: 128
+    magnitude_power: 0.5
+    scale: 0.33
+
+  decoder:
+    _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
+    fft_length: ${model.encoder.fft_length}
+    hop_length: ${model.encoder.hop_length}
+    magnitude_power: ${model.encoder.magnitude_power}
+    scale: ${model.encoder.scale}
+
+  estimator:
+    _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus
+    in_channels: 1 # single-channel noisy input
+    out_channels: 1 # single-channel estimate
+    num_res_blocks: 3 # increased number of res blocks
+    pad_time_to: 64 # pad to 64 frames in the time dimension
+    pad_dimension_to: 0 # no padding in the frequency dimension
+
+  loss:
+    _target_: nemo.collections.asr.losses.MSELoss # computed in the time domain
+
+  metrics:
+    val:
+      sisdr: # output SI-SDR
+        _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
+
+  optim:
+    name: adam
+    lr: 1e-4
+    # optimizer arguments
+    betas: [0.9, 0.999]
+    weight_decay: 0.0
+
+trainer:
+  devices: -1 # number of GPUs; -1 uses all available GPUs
+  num_nodes: 1
+  max_epochs: -1
+  max_steps: -1 # computed at runtime if not set
+  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an integer number of iterations
+  accelerator: auto
+  strategy: ddp
+  accumulate_grad_batches: 1
+  gradient_clip_val: null
+  precision: 32 # set to 16 to enable AMP with O1/O2
+  log_every_n_steps: 25 # logging interval
+  enable_progress_bar: true
+  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training; 0 disables it
+  check_val_every_n_epoch: 1 # run validation every n epochs
+  sync_batchnorm: true
+  enable_checkpointing: false # provided by exp_manager
+  logger: false # provided by exp_manager
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+
+  # use exponential moving average for model parameters
+  ema:
+    enable: true
+    decay: 0.999 # decay rate
+    cpu_offload: false # offload EMA parameters to CPU to save GPU memory
+    every_n_steps: 1 # how often to update EMA weights
+    validate_original_weights: false # use original weights for validation?
+
+  # logging
+  create_tensorboard_logger: true
+
+  # checkpointing
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    # in case of multiple validation sets, the first one is used
+    monitor: val_sisdr
+    mode: max
+    save_top_k: 5
+    always_save_nemo: true # save checkpoints as .nemo files instead of PTL checkpoints
+
+  # early stopping
+  create_early_stopping_callback: true
+  early_stopping_callback_params:
+    monitor: val_sisdr
+    mode: max
+    min_delta: 0.0
+    patience: 20 # patience in terms of check_val_every_n_epoch
+    verbose: true
+    strict: false # should be false to avoid a runtime error where EarlyStopping reports the monitor as unavailable, which sometimes happens with resumed training
+
+  resume_from_checkpoint: null # path to a checkpoint file to continue training; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
+  # the following two must be set to true to continue training
+  resume_if_exists: false
+  resume_ignore_no_checkpoint: false
+
+  # this section may be used to create a W&B logger
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    name: null
+    project: null
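The comments in this config encode a small piece of STFT arithmetic: a 2.04 s crop at 16 kHz with hop_length 128 yields 256 time frames, and fft_length 510 yields 256 subbands, so the encoded spectrogram is square and its sides are divisible by pad_time_to. A quick check:

# Quick check of the STFT arithmetic from the comments above.
sample_rate = 16000
audio_duration = 2.04
fft_length = 510
hop_length = 128

num_samples = int(audio_duration * sample_rate)  # 32640
num_frames = 1 + num_samples // hop_length       # 1 + 255 = 256 time frames
num_subbands = fft_length // 2 + 1               # 256 subbands
assert num_frames == 256 and num_subbands == 256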
New file (+149)
@@ -0,0 +1,149 @@
+name: score_based_generative_model
+
+model:
+  type: score_based
+  sample_rate: 16000
+  skip_nan_grad: false
+  num_outputs: 1
+  normalize_input: true
+  max_utts_evaluation_metrics: 50 # metric calculation needs full inference and is slow, so we limit it to the first few files
+
+  train_ds:
+    manifest_filepath: ???
+    input_key: noisy_filepath
+    target_key: clean_filepath
+    audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 256
+    random_offset: true
+    normalization_signal: input_signal
+    batch_size: 8 # batch size may be increased based on the available memory
+    shuffle: true
+    num_workers: 8
+    pin_memory: true
+
+  validation_ds:
+    manifest_filepath: ???
+    input_key: noisy_filepath
+    target_key: clean_filepath
+    normalize_input: false # load data as is for validation; the model will normalize it for inference
+    batch_size: 4
+    shuffle: false
+    num_workers: 4
+    pin_memory: true
+
+  encoder:
+    _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
+    fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
+    hop_length: 128
+    magnitude_power: 0.5
+    scale: 0.33
+
+  decoder:
+    _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
+    fft_length: ${model.encoder.fft_length}
+    hop_length: ${model.encoder.hop_length}
+    magnitude_power: ${model.encoder.magnitude_power}
+    scale: ${model.encoder.scale}
+
+  estimator:
+    _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus
+    in_channels: 2 # concatenation of single-channel perturbed and noisy input
+    out_channels: 1 # single-channel score estimate
+    conditioned_on_time: true
+    num_res_blocks: 3 # increased number of res blocks
+    pad_time_to: 64 # pad to 64 frames in the time dimension
+    pad_dimension_to: 0 # no padding in the frequency dimension
+
+  sde:
+    _target_: nemo.collections.asr.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE
+    stiffness: 1.5
+    std_min: 0.05
+    std_max: 0.5
+    num_steps: 1000
+
+  sampler:
+    _target_: nemo.collections.asr.parts.submodules.diffusion.PredictorCorrectorSampler
+    predictor: reverse_diffusion
+    corrector: annealed_langevin_dynamics
+    num_steps: 50
+    num_corrector_steps: 1
+    snr: 0.5
+
+  loss:
+    _target_: nemo.collections.asr.losses.MSELoss
+    ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time)
+
+  metrics:
+    val:
+      sisdr: # output SI-SDR
+        _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
+
+  optim:
+    name: adam
+    lr: 1e-4
+    # optimizer arguments
+    betas: [0.9, 0.999]
+    weight_decay: 0.0
+
+trainer:
+  devices: -1 # number of GPUs; -1 uses all available GPUs
+  num_nodes: 1
+  max_epochs: -1
+  max_steps: -1 # computed at runtime if not set
+  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an integer number of iterations
+  accelerator: auto
+  strategy: ddp
+  accumulate_grad_batches: 1
+  gradient_clip_val: null
+  precision: 32 # set to 16 to enable AMP with O1/O2
+  log_every_n_steps: 25 # logging interval
+  enable_progress_bar: true
+  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training; 0 disables it
+  check_val_every_n_epoch: 1 # run validation every n epochs
+  sync_batchnorm: true
+  enable_checkpointing: false # provided by exp_manager
+  logger: false # provided by exp_manager
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+
+  # use exponential moving average for model parameters
+  ema:
+    enable: true
+    decay: 0.999 # decay rate
+    cpu_offload: false # offload EMA parameters to CPU to save GPU memory
+    every_n_steps: 1 # how often to update EMA weights
+    validate_original_weights: false # use original weights for validation?
+
+  # logging
+  create_tensorboard_logger: true
+
+  # checkpointing
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    # in case of multiple validation sets, the first one is used
+    monitor: val_sisdr
+    mode: max
+    save_top_k: 5
+    always_save_nemo: true # save checkpoints as .nemo files instead of PTL checkpoints
+
+  # early stopping
+  create_early_stopping_callback: true
+  early_stopping_callback_params:
+    monitor: val_sisdr
+    mode: max
+    min_delta: 0.0
+    patience: 20 # patience in terms of check_val_every_n_epoch
+    verbose: true
+    strict: false # should be false to avoid a runtime error where EarlyStopping reports the monitor as unavailable, which sometimes happens with resumed training
+
+  resume_from_checkpoint: null # path to a checkpoint file to continue training; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
+  # the following two must be set to true to continue training
+  resume_if_exists: false
+  resume_ignore_no_checkpoint: false
+
+  # this section may be used to create a W&B logger
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    name: null
+    project: null
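For orientation, the sampler section pairs a reverse_diffusion predictor with an annealed_langevin_dynamics corrector. The sketch below is a toy version of that generic predictor-corrector scheme for a variance-exploding SDE (in the style of Song et al.); it is not NeMo's PredictorCorrectorSampler, and score_fn together with the sigma schedule are illustrative stand-ins (the schedule constants only echo std_min/std_max above).

# Toy sketch of generic predictor-corrector sampling for a variance-exploding
# SDE; not NeMo's implementation, score_fn and sigma are stand-ins.
import torch


def pc_sample(score_fn, x, sigma, num_steps=50, num_corrector_steps=1, snr=0.5):
    """score_fn(x, t) estimates grad log p_t(x); sigma(t) is the noise schedule."""
    ts = torch.linspace(1.0, 1e-3, num_steps)
    for i in range(num_steps - 1):
        t, t_next = ts[i], ts[i + 1]
        # predictor: one reverse-diffusion (Euler-Maruyama) step
        g2 = sigma(t) ** 2 - sigma(t_next) ** 2  # discrete diffusion increment
        x = x + g2 * score_fn(x, t) + g2.sqrt() * torch.randn_like(x)
        # corrector: annealed Langevin dynamics at the new noise level,
        # with the step size set from the configured signal-to-noise ratio
        for _ in range(num_corrector_steps):
            grad = score_fn(x, t_next)
            noise = torch.randn_like(x)
            step = 2 * (snr * noise.norm() / grad.norm()) ** 2
            x = x + step * grad + (2 * step).sqrt() * noise
    return x


# toy usage: exact score of a zero-mean Gaussian, VE schedule echoing
# std_min=0.05 and std_max=0.5 from the sde section above
sigma = lambda t: 0.05 * (0.5 / 0.05) ** t
score_fn = lambda x, t: -x / sigma(t) ** 2
x_T = sigma(torch.tensor(1.0)) * torch.randn(1, 256, 256)
sample = pc_sample(score_fn, x_T, sigma)

In the actual model the score network is the estimator above: it takes the perturbed spectrogram concatenated with the noisy input (hence in_channels: 2) and the time conditioning (conditioned_on_time: true), and returns a single-channel score estimate.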
