Skip to content

Commit

Permalink
Fix metrics for SE tutorial
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Oct 2, 2023
1 parent d0acb40 commit 12e0310
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
34 changes: 20 additions & 14 deletions nemo/collections/asr/models/audio_to_audio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import hydra
import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.asr.metrics.audio import AudioMetricWrapper
Expand Down Expand Up @@ -67,12 +67,15 @@ def _setup_metrics(self, tag: str = 'val'):
logging.debug('Found %d metrics for tag %s, not necesary to initialize again', num_dataloaders, tag)
return

if 'metrics' not in self._cfg or tag not in self._cfg['metrics']:
if self.cfg.get('metrics') is None:
# Metrics are not available in the configuration, nothing to do
logging.debug('No metrics configured for %s in model.metrics.%s', tag, tag)
logging.debug('No metrics configured in model.metrics')
return

metrics_cfg = self._cfg['metrics'][tag]
if (metrics_cfg := self.cfg['metrics'].get(tag)) is None:
# Metrics configuration is not available in the configuration, nothing to do
logging.debug('No metrics configured for %s in model.metrics', tag)
return

if 'loss' in metrics_cfg:
raise ValueError(
Expand All @@ -86,16 +89,19 @@ def _setup_metrics(self, tag: str = 'val'):
# Setup metrics for each dataloader
self.metrics[tag] = torch.nn.ModuleList()
for dataloader_idx in range(num_dataloaders):
metrics_dataloader_idx = torch.nn.ModuleDict(
{
name: AudioMetricWrapper(
metric=hydra.utils.instantiate(cfg),
channel=cfg.get('channel'),
metric_using_batch_averaging=cfg.get('metric_using_batch_averaging'),
)
for name, cfg in metrics_cfg.items()
}
)
metrics_dataloader_idx = {}
for name, cfg in metrics_cfg.items():
logging.debug('Initialize %s for dataloader_idx %s', name, dataloader_idx)
cfg_dict = OmegaConf.to_container(cfg)
cfg_channel = cfg_dict.pop('channel', None)
cfg_batch_averaging = cfg_dict.pop('metric_using_batch_averaging', None)
metrics_dataloader_idx[name] = AudioMetricWrapper(
metric=hydra.utils.instantiate(cfg_dict),
channel=cfg_channel,
metric_using_batch_averaging=cfg_batch_averaging,
)

metrics_dataloader_idx = torch.nn.ModuleDict(metrics_dataloader_idx)
self.metrics[tag].append(metrics_dataloader_idx.to(self.device))

logging.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,6 @@
"from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n",
"\n",
"\n",
"# Used to download data processing scripts\n",
"USER = 'anteju' # TODO: change to 'NVIDIA'\n",
"BRANCH = 'dev/se-tutorial' # TODO: change to 'r1.21.0'\n",
"\n",
"\n",
"# Utility functions for displaying signals and metrics\n",
"def show_signal(signal: np.ndarray, sample_rate: int = 16000, tag: str = 'Signal'):\n",
" \"\"\"Show the time-domain signal and its spectrogram.\n",
Expand Down Expand Up @@ -607,7 +602,7 @@
" '_target_': 'torchmetrics.audio.SignalDistortionRatio',\n",
" }\n",
"})\n",
"config.model.metrics.validation = metrics\n",
"config.model.metrics.val = metrics\n",
"config.model.metrics.test = metrics\n",
"\n",
"print(\"Metrics config:\")\n",
Expand Down Expand Up @@ -1112,7 +1107,7 @@
" 'channel': 1,\n",
" },\n",
"})\n",
"config_dual_output.model.metrics.validation = metrics\n",
"config_dual_output.model.metrics.val = metrics\n",
"config_dual_output.model.metrics.test = metrics"
]
},
Expand Down

0 comments on commit 12e0310

Please sign in to comment.