
Commit

Added consistency projection, addressed comments for the notebook
Signed-off-by: Ante Jukić <[email protected]>
anteju committed May 9, 2023
1 parent 6d95580 commit 8d58f87
Showing 3 changed files with 152 additions and 22 deletions.
9 changes: 9 additions & 0 deletions nemo/collections/asr/models/enhancement_models.py
@@ -60,6 +60,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor)
        self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)

        if 'mixture_consistency' in self._cfg:
            self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
        else:
            self.mixture_consistency = None

        # Future enhancement:
        # If subclasses need to modify the config before calling super()
        # Check what ASRBPE* classes do with their mixin
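For reference, a minimal sketch of the optional config entry this constructor looks for (the values here are hypothetical; the tutorial notebook below builds the same entry via open_dict):

from omegaconf import OmegaConf

# Hypothetical model-config fragment (not part of this commit). When the
# 'mixture_consistency' key is present, the constructor above instantiates
# the projection from it; otherwise self.mixture_consistency is None and
# the forward pass skips the projection.
cfg = OmegaConf.create(
    {
        'mixture_consistency': {
            '_target_': 'nemo.collections.asr.modules.audio_modules.MixtureConsistencyProjection',
            'weighting': 'power',
        }
    }
)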
@@ -370,6 +375,10 @@ def forward(self, input_signal, input_length=None):
        # Mask-based processor in the encoded domain
        processed, processed_length = self.mask_processor(input=encoded, input_length=encoded_length, mask=mask)

        # Mixture consistency
        if self.mixture_consistency is not None:
            processed = self.mixture_consistency(mixture=encoded, estimate=processed)

        # Decoder
        processed, processed_length = self.decoder(input=processed, input_length=processed_length)
66 changes: 66 additions & 0 deletions nemo/collections/asr/modules/audio_modules.py
@@ -878,3 +878,69 @@ def forward(
        output, output_length = self.filter(input=output, input_length=input_length, power=power)

        return output.to(io_dtype), output_length


class MixtureConsistencyProjection(NeuralModule):
    """Ensure estimated sources are consistent with the input mixture.

    Note that the input mixture is assumed to be a single-channel signal.

    Args:
        weighting: Optional weighting mode for the consistency constraint.
            If `None`, use uniform weighting. If `power`, use the power of the
            estimated source as the weight.
        eps: Small positive value for regularization.

    Reference:
        Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, 2018
    """

    def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
        super().__init__()
        self.weighting = weighting
        self.eps = eps

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "mixture": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "estimate": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
        }

    @typecheck()
    def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
        """Enforce mixture consistency on the estimated sources.

        Args:
            mixture: Single-channel mixture, shape (B, 1, F, N)
            estimate: M estimated sources, shape (B, M, F, N)

        Returns:
            Source estimates consistent with the mixture, shape (B, M, F, N)
        """
        # number of sources
        M = estimate.size(-3)
        # estimated mixture based on the estimated sources
        estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True)

        # weighting
        if self.weighting is None:
            weight = 1 / M
        elif self.weighting == 'power':
            weight = estimate.abs().pow(2)
            weight = weight / (weight.sum(dim=-3, keepdim=True) + self.eps)
        else:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

        # consistent estimate
        consistent_estimate = estimate + weight * (mixture - estimated_mixture)

        return consistent_estimate
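As a quick sanity check, here is a small usage sketch (not part of the commit; shapes and dtypes are illustrative). For each source m, the projection computes estimate_m + w_m * (mixture - sum_j estimate_j), so the corrected estimates sum back to the input mixture:

import torch

# Illustrative usage, assuming the class above is importable: two estimated
# sources and a single-channel mixture in the spectrogram domain.
projection = MixtureConsistencyProjection(weighting='power')

B, M, F, N = 1, 2, 257, 100  # batch, sources, frequency bins, frames
estimate = torch.randn(B, M, F, N, dtype=torch.cfloat)
mixture = torch.randn(B, 1, F, N, dtype=torch.cfloat)

consistent = projection(mixture=mixture, estimate=estimate)

# The projected sources sum to the mixture, up to the eps regularization
# used in the power weighting.
assert torch.allclose(consistent.sum(dim=-3, keepdim=True), mixture, atol=1e-3)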
Changes to the tutorial notebook (third changed file):
@@ -88,6 +88,7 @@
"import pytorch_lightning as pl\n",
"import soundfile as sf\n",
"\n",
"from omegaconf import OmegaConf, open_dict\n",
"from pathlib import Path\n",
"from torchmetrics.functional.audio import signal_distortion_ratio, scale_invariant_signal_distortion_ratio\n",
"\n",
@@ -462,8 +463,6 @@
"metadata": {},
"outputs": [],
"source": [
"from omegaconf import OmegaConf\n",
"\n",
"config_dir = root_dir / 'conf'\n",
"config_dir.mkdir(exist_ok=True)\n",
"\n",
@@ -561,14 +560,14 @@
"outputs": [],
"source": [
"# Setup metrics to compute on validation and test sets\n",
"metrics = {\n",
"metrics = OmegaConf.create({\n",
" 'sisdr': {\n",
" '_target_': 'torchmetrics.audio.ScaleInvariantSignalDistortionRatio',\n",
" },\n",
" 'sdr': {\n",
" '_target_': 'torchmetrics.audio.SignalDistortionRatio',\n",
" }\n",
"}\n",
"})\n",
"config.model.metrics.validation = metrics\n",
"config.model.metrics.test = metrics\n",
"\n",
@@ -598,6 +597,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a9201689",
"metadata": {},
@@ -723,7 +723,7 @@
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(enhancement_model)"
"# trainer.fit(enhancement_model)"
]
},
{
@@ -742,7 +742,7 @@
"metadata": {},
"outputs": [],
"source": [
"trainer.test(enhancement_model, ckpt_path=None)"
"# trainer.test(enhancement_model, ckpt_path=None)"
]
},
{
@@ -778,7 +778,8 @@
"\n",
"# Process using the model\n",
"noisy_tensor = torch.tensor(noisy_signal).reshape(1, 1, -1) # (batch, channel, time)\n",
"output_tensor, _ = enhancement_model(input_signal=noisy_tensor)\n",
"with torch.no_grad():\n",
" output_tensor, _ = enhancement_model(input_signal=noisy_tensor)\n",
"output_signal = output_tensor[0][0].detach().numpy()"
]
},
@@ -842,7 +843,8 @@
"min_mask_db = -10\n",
"enhancement_model.mask_processor.mask_min = db2mag(min_mask_db)\n",
"\n",
"output_tensor_min_mask, _ = enhancement_model(input_signal=noisy_tensor)\n",
"with torch.no_grad():\n",
" output_tensor_min_mask, _ = enhancement_model(input_signal=noisy_tensor)\n",
"output_signal_min_mask = output_tensor_min_mask[0][0].detach().numpy()\n",
"\n",
"print('Noisy signal')\n",
@@ -952,7 +954,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 32,
"id": "096f8c25",
"metadata": {},
"outputs": [],
@@ -962,12 +964,37 @@
"config_dual_output.model.loss.weight = [0.5, 0.5]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a0ca1eda",
"metadata": {},
"source": [
"A mixture consistency layer can be added to enforce the estimated sources (speech and noise, in this case) to be consistent with the input mixture [5]."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "a967f683",
"metadata": {},
"outputs": [],
"source": [
"# Add a mixture consistency projection\n",
"with open_dict(config_dual_output):\n",
" config_dual_output.model.mixture_consistency = OmegaConf.create({\n",
" '_target_': 'nemo.collections.asr.modules.audio_modules.MixtureConsistencyProjection',\n",
" 'weighting': 'power',\n",
" })"
]
},
{
"cell_type": "markdown",
"id": "5b9574a6",
"metadata": {},
"source": [
"Similarly, it is possible to calculate metrics for each output channel separately.\n",
"Metrics can be calculated for each output channel separately.\n",
"If a channel parameter is not provided, the configured metrics are averaged across all channels.\n",
"For example, this can be configured as follows:"
]
},
@@ -979,7 +1006,7 @@
"outputs": [],
"source": [
"# Setup metrics\n",
"metrics = {\n",
"metrics = OmegaConf.create({\n",
" # Calculate speech metric using the first channel\n",
" 'speech_sisdr': {\n",
" '_target_': 'torchmetrics.audio.ScaleInvariantSignalDistortionRatio',\n",
@@ -998,7 +1025,7 @@
" '_target_': 'torchmetrics.audio.SignalDistortionRatio',\n",
" 'channel': 1,\n",
" },\n",
"}\n",
"})\n",
"config_dual_output.model.metrics.validation = metrics\n",
"config_dual_output.model.metrics.test = metrics"
]
@@ -1025,16 +1052,14 @@
"config_dual_output.trainer.accelerator = accelerator\n",
"\n",
"# Reduces maximum number of epochs for quick demonstration\n",
"config_dual_output.trainer.max_epochs = 10\n",
"config_dual_output.trainer.max_epochs = 100 # 10\n",
"\n",
"# Remove distributed training flags\n",
"config_dual_output.trainer.strategy = None\n",
"\n",
"# Instantiate the trainer\n",
"trainer = pl.Trainer(**config_dual_output.trainer)\n",
"\n",
"exp_dir = exp_manager(trainer, config_dual_output.get(\"exp_manager\", None))\n",
"# The exp_dir provides a path to the current experiment for easy access\n",
"\n",
"print(\"Experiment directory:\")\n",
"print(exp_dir)"
@@ -1056,6 +1081,7 @@
"metadata": {},
"outputs": [],
"source": [
"dual_output_model = nemo_asr.models.EncMaskDecAudioToAudioModel(cfg=config_dual_output.model, trainer=trainer)\n",
"trainer.fit(dual_output_model)"
]
},
@@ -1071,7 +1097,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1f17ed63",
"id": "5f9d68b2",
"metadata": {},
"outputs": [],
"source": [
@@ -1096,7 +1122,9 @@
"source": [
"# (batch, channel, time)\n",
"noisy_tensor = torch.tensor(noisy_signal).reshape(1, 1, -1)\n",
"processed_tensor, _ = dual_output_model(input_signal=noisy_tensor)\n",
"\n",
"with torch.no_grad():\n",
" processed_tensor, _ = dual_output_model(input_signal=noisy_tensor)\n",
"\n",
"# First output channel is the speech estimate\n",
"output_speech = processed_tensor[0][0].detach().numpy()\n",
@@ -1120,22 +1148,49 @@
"ipd.display(ipd.Audio(output_noise, rate=sample_rate))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "cd6da1f0",
"metadata": {},
"source": [
"## Next steps\n",
"This is a simple tutorial which can serve as a starting point for prototyping and experimentation with audio-to-audio models.\n",
"A processed audio output can be used, for example, for ASR or TTS.\n",
"\n",
"For more details about NeMo models and applications in in ASR and TTS, we recommend you checkout other tutorials next:\n",
"\n",
"* [NeMo fundamentals](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/00_NeMo_Primer.ipynb)\n",
"* [NeMo models](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/01_NeMo_Models.ipynb)\n",
"* [Speech Recognition](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_with_NeMo.ipynb)\n",
"* [Speech Synthesis](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/tts/Inference_ModelSelect.ipynb)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "46a855e3",
"metadata": {},
"source": [
"### References\n",
"## References\n",
"\n",
"[1] V. Panayotov, G. Chen, D. Povery, S. Khudanpur, LibriSpeech: An ASR corpus based on public domain audio books, ICASSP 2015\n",
"[1] V. Panayotov, G. Chen, D. Povery, S. Khudanpur, \"LibriSpeech: An ASR corpus based on public domain audio books,\" ICASSP 2015\n",
"\n",
"[2] J. Thieman, N. Ito, V. Emmanuel, DEMAND: collection of multi-channel recordings of acoustic noise in diverse environments, ICA 2013\n",
"[2] J. Thieman, N. Ito, V. Emmanuel, \"DEMAND: collection of multi-channel recordings of acoustic noise in diverse environments,\" ICA 2013\n",
"\n",
"[3] K. Kinoshita, T. Ochiai, M. Delcroix, T. Nakatani, Improving noise robust automatic speech recognition with single-channel time-domain enhancement network, ICASSP 2020.\n",
"[3] K. Kinoshita, T. Ochiai, M. Delcroix, T. Nakatani, \"Improving noise robust automatic speech recognition with single-channel time-domain enhancement network,\" ICASSP 2020.\n",
"\n",
"[4] https://github.com/Lightning-AI/torchmetrics"
"[4] https://github.com/Lightning-AI/torchmetrics\n",
"\n",
"[5] Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, ICASSP 2018"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c5a97424",
"metadata": {},
"source": []
}
],
"metadata": {
