[ASR] Multichannel mask estimator with flex number of channels #7317

Merged · 2 commits · Oct 13, 2023
146 changes: 146 additions & 0 deletions examples/audio_tasks/conf/beamforming_flex_channels.yaml
@@ -0,0 +1,146 @@
# This configuration contains example values for training a multichannel speech enhancement model with a mask-based beamformer.
#
name: beamforming_flex_channels

model:
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1

  train_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    input_channel_selector: null # load all channels from the input file
    target_key: target_anechoic_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # load only the first channel from the target file
    audio_duration: 4.0 # in seconds, audio segment duration for training
    random_offset: true # if the file is longer than audio_duration, use a random offset to select a subsegment
    min_duration: ${model.train_ds.audio_duration}
    batch_size: 16 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 16
    pin_memory: true

  validation_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    input_channel_selector: null # load all channels from the input file
    target_key: target_anechoic_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # load only the first channel from the target file
    batch_size: 8
    shuffle: false
    num_workers: 8
    pin_memory: true

  channel_augment:
    _target_: nemo.collections.asr.parts.submodules.multichannel_modules.ChannelAugment
    num_channels_min: 2 # minimum number of channels selected for each batch
    num_channels_max: null # if null, the maximum is the number of channels available in the batch
    permute_channels: true

  encoder:
    _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
    fft_length: 512 # length of the window and FFT for calculating the spectrogram
    hop_length: 256 # hop length for calculating the spectrogram

  decoder:
    _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
    fft_length: ${model.encoder.fft_length}
    hop_length: ${model.encoder.hop_length}

  mask_estimator:
    _target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorFlexChannels
    num_outputs: ${model.num_outputs} # number of output masks
    num_subbands: 257 # number of subbands of the input spectrogram
    num_blocks: 5 # number of blocks in the model
    channel_reduction_position: 3 # 0-indexed, apply channel reduction before this block
    channel_reduction_type: average # channel-wise reduction
    channel_block_type: transform_average_concatenate # channel block
    temporal_block_type: conformer_encoder # temporal block
    temporal_block_num_layers: 5 # number of layers for the temporal block
    temporal_block_num_heads: 4 # number of heads for the temporal block
    temporal_block_dimension: 128 # hidden size of the temporal block
    mag_reduction: null # channel-wise reduction of magnitude
    mag_normalization: mean_var # normalization using mean and variance
    use_ipd: true # use inter-channel phase difference
    ipd_normalization: mean # mean normalization

  mask_processor:
    # Mask-based multichannel processor
    _target_: nemo.collections.asr.modules.audio_modules.MaskBasedBeamformer
    filter_type: pmwf # parametric multichannel Wiener filter
    filter_beta: 0.0 # beta = 0 corresponds to the MVDR beamformer
    filter_rank: one
    ref_channel: max_snr # select the reference channel by maximizing the estimated SNR
    ref_hard: 1 # use a one-hot (hard) reference; if false, a soft estimate across channels is used
    ref_hard_use_grad: false # use a straight-through gradient when using a hard reference
    ref_subband_weighting: false # use subband weighting for reference estimation
    num_subbands: ${model.mask_estimator.num_subbands}

  loss:
    _target_: nemo.collections.asr.losses.SDRLoss
    convolution_invariant: true # convolution-invariant loss
    sdr_max: 30 # soft threshold for SDR

  metrics:
    val:
      sdr_0:
        _target_: torchmetrics.audio.SignalDistortionRatio
        channel: 0 # if there are multiple outputs, evaluate only on channel 0

  optim:
    name: adamw
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1e-3

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 10000
      warmup_ratio: null
      min_lr: 1e-6

trainer:
  devices: -1 # number of GPUs; -1 uses all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: null
  precision: 32 # should be set to 16 for O1 and O2 to enable AMP
  log_every_n_steps: 25 # logging interval
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training; 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # provided by exp_manager
  logger: false # provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, the first one is used
    monitor: "val_loss"
    mode: "min"
    save_top_k: 5
    always_save_nemo: true # also save checkpoints as .nemo files, in addition to the PTL checkpoints

  resume_from_checkpoint: null # path to a checkpoint file to continue training; restores the whole state including the epoch, step, LR schedulers, etc.
  # set these two to true to resume training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
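For reference, a minimal sketch of how a config like this might be used to launch training (assumptions: `EncMaskDecAudioToAudioModel` is the model class paired with this config, and the exact example script shipped with NeMo may differ):

```python
# Minimal training-launch sketch, assuming the usual NeMo + PyTorch Lightning
# entry point. The manifest paths marked ??? in the YAML must be filled in;
# the filenames below are placeholders.
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models.enhancement_models import EncMaskDecAudioToAudioModel
from nemo.utils.exp_manager import exp_manager

cfg = OmegaConf.load('examples/audio_tasks/conf/beamforming_flex_channels.yaml')
cfg.model.train_ds.manifest_filepath = 'train_manifest.json'    # placeholder
cfg.model.validation_ds.manifest_filepath = 'val_manifest.json' # placeholder

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get('exp_manager', None))

model = EncMaskDecAudioToAudioModel(cfg=cfg.model, trainer=trainer)
trainer.fit(model)
```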
9 changes: 9 additions & 0 deletions examples/audio_tasks/process_audio.py
@@ -37,6 +37,7 @@
pretrained_name: name of a pretrained AudioToAudioModel model (from NGC registry)
audio_dir: path to directory with audio files
dataset_manifest: path to dataset JSON manifest file (in NeMo format)
max_utts: maximum number of utterances to process

input_channel_selector: list of channels to take from audio files, defaults to `None` and takes all available channels
input_key: key for audio filepath in the manifest file, defaults to `audio_filepath`
@@ -80,6 +81,7 @@ class ProcessConfig:
pretrained_name: Optional[str] = None # Name of a pretrained model
audio_dir: Optional[str] = None # Path to a directory which contains audio files
dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
max_utts: Optional[int] = None # max number of utterances to process

# Audio configs
input_channel_selector: Optional[List] = None # Union types not supported Optional[Union[List, int]]
@@ -171,6 +173,10 @@ def main(cfg: ProcessConfig) -> ProcessConfig:
audio_file = manifest_dir / audio_file
filepaths.append(str(audio_file.absolute()))

if cfg.max_utts is not None:
# Limit the number of utterances to process
filepaths = filepaths[: cfg.max_utts]

logging.info(f"\nProcessing {len(filepaths)} files...\n")

# setup AMP (optional)
@@ -225,6 +231,9 @@ def autocast():
item = json.loads(line)
item['processed_audio_filepath'] = paths2processed_files[idx]
f.write(json.dumps(item) + "\n")

if cfg.max_utts is not None and idx >= cfg.max_utts - 1:
break
else:
for idx, processed_file in enumerate(paths2processed_files):
item = {'processed_audio_filepath': processed_file}
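With `max_utts` in place, a quick smoke test can process just a handful of files, e.g. `python process_audio.py dataset_manifest=manifest.json max_utts=10` (hypothetical paths, and assuming the script's usual Hydra-style CLI overrides).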
2 changes: 1 addition & 1 deletion nemo/collections/asr/data/audio_to_audio.py
@@ -636,7 +636,7 @@ def get_duration(audio_files: List[str]) -> List[float]:
Returns:
List of durations in seconds.
"""
duration = [librosa.get_duration(filename=f) for f in flatten(audio_files)]
duration = [librosa.get_duration(path=f) for f in flatten(audio_files)]
return duration

def load_embedding(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
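The `filename` to `path` rename here appears to track the keyword-argument change in librosa 0.10, where `get_duration(filename=...)` was replaced by `get_duration(path=...)`.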
8 changes: 4 additions & 4 deletions nemo/collections/asr/losses/audio_losses.py
@@ -121,8 +121,8 @@ def convolution_invariant_target(
input_length: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
filter_length: int = 512,
diag_reg: float = 1e-8,
eps: float = 1e-10,
diag_reg: float = 1e-6,
eps: float = 1e-8,
) -> torch.Tensor:
"""Calculate optimal convolution-invariant target for a given estimate.
Assumes time dimension is the last dimension in the array.
@@ -222,7 +222,7 @@ def calculate_sdr_batch(
convolution_filter_length: Optional[int] = 512,
remove_mean: bool = True,
sdr_max: Optional[float] = None,
eps: float = 1e-10,
eps: float = 1e-8,
) -> torch.Tensor:
"""Calculate signal-to-distortion ratio per channel.

@@ -310,7 +310,7 @@ def __init__(
convolution_filter_length: Optional[int] = 512,
remove_mean: bool = True,
sdr_max: Optional[float] = None,
eps: float = 1e-10,
eps: float = 1e-8,
):
super().__init__()

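These hunks loosen the numerical constants (`diag_reg` from 1e-8 to 1e-6, `eps` from 1e-10 to 1e-8). The `eps` value acts as a floor inside the log-ratio, while `sdr_max` (used as `sdr_max: 30` in the config above) is a soft threshold on the achievable SDR. A small sketch of the idea, not NeMo's exact implementation:

```python
# Sketch of SDR with a numerical floor (eps) and a soft threshold (sdr_max).
# Illustrative only; NeMo's audio_losses.py is the authoritative version.
from typing import Optional

import torch


def sdr_db(estimate: torch.Tensor, target: torch.Tensor,
           sdr_max: Optional[float] = None, eps: float = 1e-8) -> torch.Tensor:
    """SDR in dB, computed over the last (time) dimension."""
    target_power = torch.sum(target ** 2, dim=-1)
    distortion_power = torch.sum((estimate - target) ** 2, dim=-1)
    if sdr_max is not None:
        # Soft threshold: adding alpha * target_power to the distortion caps
        # the achievable SDR near sdr_max dB instead of letting it grow unbounded.
        alpha = 10.0 ** (-sdr_max / 10.0)
        distortion_power = distortion_power + alpha * target_power
    return 10.0 * torch.log10(target_power / (distortion_power + eps) + eps)
```

With `sdr_max=30`, an otherwise perfect estimate saturates around 30 dB, which keeps the loss from being dominated by examples that are already well separated.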
19 changes: 13 additions & 6 deletions nemo/collections/asr/models/audio_to_audio_model.py
@@ -121,15 +121,24 @@ def on_test_start(self):
return super().on_test_start()

def validation_step(self, batch, batch_idx, dataloader_idx: int = 0):
return self.evaluation_step(batch, batch_idx, dataloader_idx, 'val')
output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'val')
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output_dict)
else:
self.validation_step_outputs.append(output_dict)
return output_dict

def test_step(self, batch, batch_idx, dataloader_idx=0):
return self.evaluation_step(batch, batch_idx, dataloader_idx, 'test')
output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'test')
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output_dict)
else:
self.test_step_outputs.append(output_dict)
return output_dict

def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
# Handle loss
loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
output_dict = {f'{tag}_loss': loss_mean}
tensorboard_logs = {f'{tag}_loss': loss_mean}

# Handle metrics for this tag and dataloader_idx
@@ -141,9 +150,7 @@ def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str
# Store for logs
tensorboard_logs[f'{tag}_{name}'] = value

output_dict['log'] = tensorboard_logs

return output_dict
return {f'{tag}_loss': loss_mean, 'log': tensorboard_logs}

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
return self.multi_evaluation_epoch_end(outputs, dataloader_idx, 'val')
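Context for this change: under PyTorch Lightning 2.x, step outputs are no longer passed to an aggregation hook automatically, so each step appends its output dict to `validation_step_outputs` / `test_step_outputs` (one list per dataloader when there are several). A sketch of how those lists are typically consumed at epoch end; NeMo's base class handles this, so details may differ:

```python
# Sketch of the Lightning 2.x epoch-end pattern that consumes the outputs
# collected in validation_step above. Illustrative only.
def on_validation_epoch_end(self):
    if self.validation_step_outputs and isinstance(self.validation_step_outputs[0], list):
        # Multiple dataloaders: one list of output dicts per dataloader
        for idx, outputs in enumerate(self.validation_step_outputs):
            self.multi_validation_epoch_end(outputs, dataloader_idx=idx)
            outputs.clear()  # release memory before the next epoch
    else:
        self.multi_validation_epoch_end(self.validation_step_outputs, dataloader_idx=0)
        self.validation_step_outputs.clear()
```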
53 changes: 30 additions & 23 deletions nemo/collections/asr/models/enhancement_models.py
@@ -61,14 +61,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)

if 'mixture_consistency' in self._cfg:
logging.debug('Using mixture consistency')
self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
else:
logging.debug('Mixture consistency not used')
self.mixture_consistency = None

# Future enhancement:
# if subclasses need to modify the config before calling super(),
# check how the ASRBPE* classes do this with their mixin

# Setup augmentation
if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None:
logging.debug('Using channel augmentation')
self.channel_augmentation = EncMaskDecAudioToAudioModel.from_config_dict(self.cfg.channel_augment)
else:
logging.debug('Channel augmentation not used')
self.channel_augmentation = None

# Setup optional Optimization flags
self.setup_optimization_flags()

@@ -125,7 +135,7 @@ def process(
temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json')
with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp:
for audio_file in paths2audio_files:
entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(filename=audio_file)}
entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)}
fp.write(json.dumps(entry) + '\n')

config = {
@@ -397,17 +407,23 @@ def training_step(self, batch, batch_idx):
if target_signal.ndim == 2:
target_signal = target_signal.unsqueeze(1)

# Apply channel augmentation
if self.training and self.channel_augmentation is not None:
input_signal = self.channel_augmentation(input=input_signal)

# Process input
processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)

loss_value = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)
# Calculate the loss
loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)

tensorboard_logs = {
'train_loss': loss_value,
'learning_rate': self._optimizer.param_groups[0]['lr'],
'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32),
}
# Logs
self.log('train_loss', loss)
self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

return {'loss': loss_value, 'log': tensorboard_logs}
# Return loss
return loss

def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
input_signal, input_length, target_signal, target_length = batch
@@ -419,11 +435,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
if target_signal.ndim == 2:
target_signal = target_signal.unsqueeze(1)

# Process input
processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)

# Prepare output
loss_value = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)
output_dict = {f'{tag}_loss': loss_value}
# Calculate the loss
loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)

# Update metrics
if hasattr(self, 'metrics') and tag in self.metrics:
Expand All @@ -432,19 +448,10 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
metric.update(preds=processed_signal, target=target_signal, input_length=input_length)

# Log global step
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32), sync_dist=True)
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

if tag == 'val':
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output_dict)
else:
self.validation_step_outputs.append(output_dict)
else:
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output_dict)
else:
self.test_step_outputs.append(output_dict)
Author's note: moved the handling of multiple dataloaders into validation_step and test_step in audio_to_audio_model.py.

return output_dict
# Return loss
return {f'{tag}_loss': loss}

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
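For context, `ChannelAugment` (configured in the YAML above and applied in `training_step`) trains the flex-channel estimator on a varying number of channels. An illustrative sketch of the idea, not the module's actual code:

```python
# Illustrative channel augmentation: keep a random subset of input channels,
# optionally in random order. Not the actual ChannelAugment implementation.
import torch


def channel_augment(x: torch.Tensor, num_channels_min: int = 2,
                    permute_channels: bool = True) -> torch.Tensor:
    """x: multichannel input with shape (batch, channels, time)."""
    num_channels_in = x.size(1)
    # Draw the number of output channels in [num_channels_min, num_channels_in]
    num_channels_out = int(torch.randint(num_channels_min, num_channels_in + 1, (1,)))
    channels = torch.randperm(num_channels_in)[:num_channels_out]
    if not permute_channels:
        # Keep the original channel order
        channels, _ = torch.sort(channels)
    return x[:, channels, :]
```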
7 changes: 6 additions & 1 deletion nemo/collections/asr/modules/__init__.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.asr.modules.audio_modules import MaskBasedBeamformer, MaskEstimatorRNN, MaskReferenceChannel
from nemo.collections.asr.modules.audio_modules import (
MaskBasedBeamformer,
MaskEstimatorFlexChannels,
MaskEstimatorRNN,
MaskReferenceChannel,
)
from nemo.collections.asr.modules.audio_preprocessing import (
AudioToMelSpectrogramPreprocessor,
AudioToMFCCPreprocessor,
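With the export added here, the new estimator can be imported directly from the modules package. A small instantiation sketch using the hyperparameters from the YAML config above (illustrative; the class docstring is the authoritative reference for the signature):

```python
from nemo.collections.asr.modules import MaskEstimatorFlexChannels

# Hyperparameters mirror beamforming_flex_channels.yaml above.
mask_estimator = MaskEstimatorFlexChannels(
    num_outputs=1,
    num_subbands=257,
    num_blocks=5,
    channel_reduction_position=3,
    channel_reduction_type='average',
    channel_block_type='transform_average_concatenate',
    temporal_block_type='conformer_encoder',
    temporal_block_num_layers=5,
    temporal_block_num_heads=4,
    temporal_block_dimension=128,
    mag_reduction=None,
    mag_normalization='mean_var',
    use_ipd=True,
    ipd_normalization='mean',
)
```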