Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix metrics for SE tutorial #7604

Merged
merged 2 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions nemo/collections/asr/models/audio_to_audio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import hydra
import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.asr.metrics.audio import AudioMetricWrapper
Expand Down Expand Up @@ -67,12 +67,15 @@ def _setup_metrics(self, tag: str = 'val'):
logging.debug('Found %d metrics for tag %s, not necesary to initialize again', num_dataloaders, tag)
return

if 'metrics' not in self._cfg or tag not in self._cfg['metrics']:
if self.cfg.get('metrics') is None:
# Metrics are not available in the configuration, nothing to do
logging.debug('No metrics configured for %s in model.metrics.%s', tag, tag)
logging.debug('No metrics configured in model.metrics')
return

metrics_cfg = self._cfg['metrics'][tag]
if (metrics_cfg := self.cfg['metrics'].get(tag)) is None:
# Metrics configuration is not available in the configuration, nothing to do
logging.debug('No metrics configured for %s in model.metrics', tag)
return

if 'loss' in metrics_cfg:
raise ValueError(
Expand All @@ -86,16 +89,19 @@ def _setup_metrics(self, tag: str = 'val'):
# Setup metrics for each dataloader
self.metrics[tag] = torch.nn.ModuleList()
for dataloader_idx in range(num_dataloaders):
metrics_dataloader_idx = torch.nn.ModuleDict(
{
name: AudioMetricWrapper(
metric=hydra.utils.instantiate(cfg),
channel=cfg.get('channel'),
metric_using_batch_averaging=cfg.get('metric_using_batch_averaging'),
)
for name, cfg in metrics_cfg.items()
}
)
metrics_dataloader_idx = {}
for name, cfg in metrics_cfg.items():
logging.debug('Initialize %s for dataloader_idx %s', name, dataloader_idx)
cfg_dict = OmegaConf.to_container(cfg)
cfg_channel = cfg_dict.pop('channel', None)
cfg_batch_averaging = cfg_dict.pop('metric_using_batch_averaging', None)
metrics_dataloader_idx[name] = AudioMetricWrapper(
metric=hydra.utils.instantiate(cfg_dict),
channel=cfg_channel,
metric_using_batch_averaging=cfg_batch_averaging,
)

metrics_dataloader_idx = torch.nn.ModuleDict(metrics_dataloader_idx)
self.metrics[tag].append(metrics_dataloader_idx.to(self.device))

logging.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,6 @@
"from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n",
"\n",
"\n",
"# Used to download data processing scripts\n",
"USER = 'anteju' # TODO: change to 'NVIDIA'\n",
"BRANCH = 'dev/se-tutorial' # TODO: change to 'r1.21.0'\n",
"\n",
"\n",
"# Utility functions for displaying signals and metrics\n",
"def show_signal(signal: np.ndarray, sample_rate: int = 16000, tag: str = 'Signal'):\n",
" \"\"\"Show the time-domain signal and its spectrogram.\n",
Expand Down Expand Up @@ -607,7 +602,7 @@
" '_target_': 'torchmetrics.audio.SignalDistortionRatio',\n",
" }\n",
"})\n",
"config.model.metrics.validation = metrics\n",
"config.model.metrics.val = metrics\n",
"config.model.metrics.test = metrics\n",
"\n",
"print(\"Metrics config:\")\n",
Expand Down Expand Up @@ -1112,7 +1107,7 @@
" 'channel': 1,\n",
" },\n",
"})\n",
"config_dual_output.model.metrics.validation = metrics\n",
"config_dual_output.model.metrics.val = metrics\n",
"config_dual_output.model.metrics.test = metrics"
]
},
Expand Down
Loading