From d8cf7381304d65808df6569e7a38f532efba9468 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ante=20Jukic=CC=81?=
Date: Wed, 8 Feb 2023 15:29:59 -0800
Subject: [PATCH] [ASR] Added a script for evaluating audio-to-audio metrics
 for a manifest file (audio_to_audio_eval.py)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Ante Jukić
---
 .../asr/audio_to_audio/audio_to_audio_eval.py | 278 ++++++++++++++++++
 examples/asr/audio_to_audio/process_audio.py  |   9 +-
 nemo/collections/asr/metrics/audio.py         |   3 -
 .../asr/models/enhancement_models.py          |   1 -
 .../common/parts/preprocessing/manifest.py    |  30 +-
 5 files changed, 302 insertions(+), 19 deletions(-)
 create mode 100644 examples/asr/audio_to_audio/audio_to_audio_eval.py

diff --git a/examples/asr/audio_to_audio/audio_to_audio_eval.py b/examples/asr/audio_to_audio/audio_to_audio_eval.py
new file mode 100644
index 000000000000..57d7095057a9
--- /dev/null
+++ b/examples/asr/audio_to_audio/audio_to_audio_eval.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Script to compute audio-to-audio metrics for a given model and a given dataset manifest file.
+The manifest file must include paths to the input audio and to the target (ground truth) audio.
+
+Note: This script depends on the `process_audio.py` script, so both scripts should be
+located in the same directory during execution.
+
+# Arguments
+
+<< All arguments of `process_audio.py` are inherited by this script, so please refer to `process_audio.py`
+for the full list of arguments >>
+
+  dataset_manifest: Required - path to dataset JSON manifest file (in NeMo format)
+  output_dir: Optional - output directory where the processed audio will be saved
+  metrics: Optional - list of metrics to evaluate. Defaults to [sdr,estoi]
+  sample_rate: Optional - sample rate for loaded audio. Defaults to 16kHz
+  only_score_manifest: Optional - if set, processing is skipped and the processed audio is assumed to be available in dataset_manifest
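+
+For reference, a manifest line for the scoring-only mode could look as follows. The paths are
+illustrative; the keys shown are the defaults used by this script (`processed_audio_filepath`,
+`target_audio_filepath`), and `duration` is read if present:
+
+  {"processed_audio_filepath": "processed/utt01.wav", "target_audio_filepath": "target/utt01.wav", "duration": 4.25}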
+
+# Usage
+
+## To score a dataset with a manifest file that contains the input audio (to be processed) and the target audio
+
+python audio_to_audio_eval.py \
+    model_path=null \
+    pretrained_model=null \
+    dataset_manifest= \
+    output_dir= \
+    processed_channel_selector= \
+    target_key= \
+    target_channel_selector= \
+    metrics= \
+    batch_size=32 \
+    amp=True
+
+## To score a manifest file which has been previously processed and contains both the processed audio and the target audio
+
+python audio_to_audio_eval.py \
+    dataset_manifest= \
+    processed_key= \
+    processed_channel_selector= \
+    target_key= \
+    target_channel_selector= \
+    metrics= \
+    batch_size=32 \
+    amp=True
+"""
+import json
+import os
+import tempfile
+from dataclasses import dataclass, field, is_dataclass
+from typing import List, Optional
+
+import process_audio
+import torch
+from omegaconf import OmegaConf, open_dict
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
+from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
+from tqdm import tqdm
+
+from nemo.collections.asr.data import audio_to_audio_dataset
+from nemo.collections.asr.metrics.audio import AudioMetricWrapper
+from nemo.collections.common.parts.preprocessing import manifest
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+
+
+@dataclass
+class AudioEvaluationConfig(process_audio.ProcessConfig):
+    # Processed audio config
+    processed_channel_selector: Optional[List] = None
+    processed_key: str = 'processed_audio_filepath'
+
+    # Target audio config
+    target_dataset_dir: Optional[str] = None  # If not provided, defaults to dirname(cfg.dataset_manifest)
+    target_channel_selector: Optional[List] = None
+    target_key: str = 'target_audio_filepath'
+
+    # Sample rate for audio evaluation
+    sample_rate: int = 16000
+
+    # Score an existing manifest without running processing
+    only_score_manifest: bool = False
+
+    # Metrics to calculate
+    metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi'])
+
+
+def get_evaluation_dataloader(config):
+    """Prepare a dataloader for evaluation.
+    """
+    dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config)
+
+    return torch.utils.data.DataLoader(
+        dataset=dataset,
+        batch_size=config['batch_size'],
+        collate_fn=dataset.collate_fn,
+        drop_last=config.get('drop_last', False),
+        shuffle=False,
+        num_workers=config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)),
+        pin_memory=True,
+    )
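+
+
+# Illustrative sketch of the config consumed by `get_evaluation_dataloader` above. The key
+# names mirror the config built in `main` below; the values here are placeholders:
+#
+#     config = {
+#         'manifest_filepath': 'manifest.json',
+#         'sample_rate': 16000,
+#         'input_key': 'processed',
+#         'target_key': 'target',
+#         'batch_size': 8,
+#         'num_workers': 4,
+#     }
+#     dataloader = get_evaluation_dataloader(config)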
+
+
+def get_metrics(cfg: AudioEvaluationConfig):
+    """Prepare a dictionary with metrics.
+    """
+    available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq']
+
+    metrics = dict()
+    for name in sorted(set(cfg.metrics)):
+        name = name.lower()
+        if name == 'sdr':
+            metric = AudioMetricWrapper(metric=SignalDistortionRatio())
+        elif name == 'sisdr':
+            metric = AudioMetricWrapper(metric=ScaleInvariantSignalDistortionRatio())
+        elif name == 'stoi':
+            metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=False))
+        elif name == 'estoi':
+            metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=True))
+        elif name == 'pesq':
+            metric = AudioMetricWrapper(metric=PerceptualEvaluationSpeechQuality(fs=cfg.sample_rate, mode='wb'))
+        else:
+            raise ValueError(f'Unexpected metric: {name}. Currently available metrics: {available_metrics}')
+
+        metrics[name] = metric
+
+    return metrics
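+
+
+# Illustrative use of the metrics dictionary returned by `get_metrics` above, e.g., with
+# cfg.metrics=['sdr'] and example tensors of shape (batch, channels, time), following
+# the `AudioMetricWrapper.update` signature:
+#
+#     metrics = get_metrics(cfg)
+#     preds = torch.randn(2, 1, 16000)
+#     target = torch.randn(2, 1, 16000)
+#     for metric in metrics.values():
+#         metric.update(preds=preds, target=target)
+#     values = {name: metric.compute().item() for name, metric in metrics.items()}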
+
+
+@hydra_runner(config_name="AudioEvaluationConfig", schema=AudioEvaluationConfig)
+def main(cfg: AudioEvaluationConfig):
+    torch.set_grad_enabled(False)
+
+    if is_dataclass(cfg):
+        cfg = OmegaConf.structured(cfg)
+
+    if cfg.audio_dir is not None:
+        raise RuntimeError(
+            "Evaluation script requires ground truth audio to be passed via a manifest file. "
+            "If a manifest file is available, submit it via the `dataset_manifest` argument."
+        )
+
+    if not os.path.exists(cfg.dataset_manifest):
+        raise FileNotFoundError(f'The dataset manifest file could not be found at path: {cfg.dataset_manifest}')
+
+    if cfg.target_dataset_dir is None:
+        # Assume the target data is available in the same directory as the input data
+        cfg.target_dataset_dir = os.path.dirname(cfg.dataset_manifest)
+    elif not os.path.isdir(cfg.target_dataset_dir):
+        raise FileNotFoundError(f'Target dataset dir could not be found at path: {cfg.target_dataset_dir}')
+
+    # Setup metrics
+    metrics = get_metrics(cfg)
+
+    # Processing
+    if not cfg.only_score_manifest:
+        # Process audio using the configured model and save in the output directory
+        process_cfg = process_audio.main(cfg)  # type: ProcessConfig
+
+        # Release GPU memory if it was used during processing
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        logging.info('Finished processing audio.')
+    else:
+        # Score the input manifest, no need to run a model
+        cfg.output_filename = cfg.dataset_manifest
+        process_cfg = cfg
+
+    # Evaluation
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        # Prepare a temporary manifest with processed audio and target
+        temporary_manifest_filepath = os.path.join(tmp_dir, 'manifest.json')
+
+        num_files = 0
+
+        with open(process_cfg.output_filename, 'r') as f_processed, open(
+            temporary_manifest_filepath, 'w', encoding='utf-8'
+        ) as f_tmp:
+            for line_processed in f_processed:
+                data_processed = json.loads(line_processed)
+
+                if cfg.processed_key not in data_processed:
+                    raise ValueError(
+                        f'Processed key {cfg.processed_key} not found in manifest: {process_cfg.output_filename}.'
+                    )
+
+                if cfg.target_key not in data_processed:
+                    raise ValueError(
+                        f'Target key {cfg.target_key} not found in manifest: {process_cfg.output_filename}.'
+                    )
+
+                item = {
+                    'processed': manifest.get_full_path(
+                        audio_file=data_processed[cfg.processed_key], manifest_file=process_cfg.output_filename
+                    ),
+                    'target': manifest.get_full_path(
+                        audio_file=data_processed[cfg.target_key], data_dir=cfg.target_dataset_dir
+                    ),
+                    'duration': data_processed.get('duration'),
+                }
+
+                # Double-check files exist
+                for key in ['processed', 'target']:
+                    if not os.path.isfile(item[key]):
+                        raise ValueError(f'File for key "{key}" not found at: {item[key]}.\nCurrent item: {item}')
+
+                # Warn if we're comparing the same files
+                if item['target'] == item['processed']:
+                    logging.warning('Using the same file as processed and target: %s', item['target'])
+
+                # Write the entry in the temporary manifest file
+                f_tmp.write(json.dumps(item) + '\n')
+
+                num_files += 1
+
+        # Prepare dataloader
+        config = {
+            'manifest_filepath': temporary_manifest_filepath,
+            'sample_rate': cfg.sample_rate,
+            'input_key': 'processed',
+            'input_channel_selector': cfg.processed_channel_selector,
+            'target_key': 'target',
+            'target_channel_selector': cfg.target_channel_selector,
+            'batch_size': min(cfg.batch_size, num_files),
+            'num_workers': cfg.num_workers,
+        }
+        temporary_dataloader = get_evaluation_dataloader(config)
+
+        # Calculate metrics
+        for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'):
+            processed_signal, processed_length, target_signal, target_length = eval_batch
+
+            if not torch.equal(processed_length, target_length):
+                raise RuntimeError(
+                    f'Length mismatch between processed and target signals: {processed_length} vs {target_length}.'
+                )
+
+            for name, metric in metrics.items():
+                metric.update(preds=processed_signal, target=target_signal, input_length=target_length)
+
+    # Convert to a dictionary with name: value
+    metrics_value = {name: metric.compute().item() for name, metric in metrics.items()}
+
+    logging.info('Finished running evaluation.')
+
+    # Show results
+    logging.info('Summary\n')
+    logging.info('Data')
+    logging.info('\tmanifest: %s', cfg.output_filename)
+    logging.info('\ttarget_dataset_dir: %s', cfg.target_dataset_dir)
+    logging.info('\tnum_files: %s', num_files)
+    logging.info('Metrics')
+    for name, value in metrics_value.items():
+        logging.info('\t%10s: \t%6.2f', name, value)
+
+    # Inject the metric name and score into the config, and return the entire config
+    with open_dict(cfg):
+        cfg.metrics_value = metrics_value
+
+    return cfg
+
+
+if __name__ == '__main__':
+    main()  # noqa pylint: disable=no-value-for-parameter
diff --git a/examples/asr/audio_to_audio/process_audio.py b/examples/asr/audio_to_audio/process_audio.py
index 6441f685ca4f..20650d8a8c3c 100644
--- a/examples/asr/audio_to_audio/process_audio.py
+++ b/examples/asr/audio_to_audio/process_audio.py
@@ -198,13 +198,11 @@ def autocast():
 
     # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
     if not cfg.overwrite_output and os.path.exists(cfg.output_dir):
-        logging.warning(
+        raise RuntimeError(
             f"Previous output found at {cfg.output_dir}, and flag `overwrite_output`"
             f"is {cfg.overwrite_output}. Returning without processing."
         )
-        return cfg
-
     # Process audio
     with autocast():
         with torch.no_grad():
@@ -225,12 +223,11 @@ def autocast():
             with open(cfg.dataset_manifest, 'r') as fr:
                 for idx, line in enumerate(fr):
                     item = json.loads(line)
-                    item['audio_filepath_unprocessed'] = item[input_key]
-                    item['audio_filepath'] = paths2processed_files[idx]
+                    item['processed_audio_filepath'] = paths2processed_files[idx]
                     f.write(json.dumps(item) + "\n")
         else:
             for idx, processed_file in enumerate(paths2processed_files):
-                item = {'audio_filepath': processed_file}
+                item = {'processed_audio_filepath': processed_file}
                 f.write(json.dumps(item) + "\n")
 
     return cfg
diff --git a/nemo/collections/asr/metrics/audio.py b/nemo/collections/asr/metrics/audio.py
index df48c4a8c583..5e8c2915e3fa 100644
--- a/nemo/collections/asr/metrics/audio.py
+++ b/nemo/collections/asr/metrics/audio.py
@@ -134,9 +134,6 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, input_length: Option
             target: tensor with target signals, shape (B, C, T)
             input_length: Optional, input tensor with length (in samples) of each signal in the batch, shape (B,).
                 If not provided, it is assumed that all samples are valid.
-
-        Returns:
-            Underlying metric averaged on the current batch.
         """
         preds, target = self._select_channel(preds=preds, target=target)
 
diff --git a/nemo/collections/asr/models/enhancement_models.py b/nemo/collections/asr/models/enhancement_models.py
index 9d5b711b9e1c..08a2b648aa32 100644
--- a/nemo/collections/asr/models/enhancement_models.py
+++ b/nemo/collections/asr/models/enhancement_models.py
@@ -129,7 +129,6 @@ def process(
             'input_channel_selector': input_channel_selector,
             'batch_size': min(batch_size, len(paths2audio_files)),
             'num_workers': num_workers,
-            'channel_selector': input_channel_selector,
         }
 
         # Create output dir if necessary
diff --git a/nemo/collections/common/parts/preprocessing/manifest.py b/nemo/collections/common/parts/preprocessing/manifest.py
index efe850fc0e8c..65147f953d0e 100644
--- a/nemo/collections/common/parts/preprocessing/manifest.py
+++ b/nemo/collections/common/parts/preprocessing/manifest.py
@@ -159,7 +159,12 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]:
     return item
 
 
-def get_full_path(audio_file: str, manifest_file: str, audio_file_len_limit: int = 255) -> str:
+def get_full_path(
+    audio_file: str,
+    manifest_file: Optional[str] = None,
+    data_dir: Optional[str] = None,
+    audio_file_len_limit: int = 255,
+) -> str:
     """Get full path to audio_file.
 
     If the audio_file is a relative path and does not exist,
@@ -169,8 +174,9 @@ def get_full_path(audio_file: str, manifest_file: str, audio_file_len_limit: int
 
     Args:
         audio_file: path to an audio file, either absolute or assumed relative
-            to the manifest directory
+            to the manifest directory or data directory
         manifest_file: path to a manifest file
+        data_dir: path to a directory containing data; used only if a manifest file is not provided
         audio_file_len_limit: limit for length of audio_file when using relative paths
 
     Returns:
@@ -178,15 +184,21 @@ def get_full_path(audio_file: str, manifest_file: str, audio_file_len_limit: int
     """
     audio_file = Path(audio_file)
 
-    if is_datastore_path(manifest_file):
-        # WORKAROUND: pathlib does not support URIs, so use os.path
-        manifest_dir = os.path.dirname(manifest_file)
-    else:
-        manifest_dir = Path(manifest_file).parent.as_posix()
+    if manifest_file is None and data_dir is None:
+        raise ValueError('Use either manifest_file or data_dir to specify the data directory.')
+    elif manifest_file is not None and data_dir is not None:
+        raise ValueError('Parameters manifest_file and data_dir cannot be used simultaneously.')
+
+    if data_dir is None:
+        if is_datastore_path(manifest_file):
+            # WORKAROUND: pathlib does not support URIs, so use os.path
+            data_dir = os.path.dirname(manifest_file)
+        else:
+            data_dir = Path(manifest_file).parent.as_posix()
 
     if (len(str(audio_file)) < audio_file_len_limit) and not audio_file.is_file() and not audio_file.is_absolute():
-        # assume audio_file path is relative to manifest_dir
-        audio_file_path = os.path.join(manifest_dir, audio_file.as_posix())
+        # assume audio_file path is relative to data_dir
+        audio_file_path = os.path.join(data_dir, audio_file.as_posix())
         if is_datastore_path(audio_file_path):
             # If audio was originally on an object store, use locally-cached path
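
# Illustrative usage of the updated `get_full_path`; the file and directory names below are
# examples only, and only one of `manifest_file` or `data_dir` may be given per call:
#
#     from nemo.collections.common.parts.preprocessing.manifest import get_full_path
#
#     # Resolve a relative path against the manifest's directory (existing behavior)
#     path = get_full_path(audio_file='audio/utt01.wav', manifest_file='/data/manifest.json')
#
#     # Resolve a relative path against an explicit data directory (new in this patch)
#     path = get_full_path(audio_file='audio/utt01.wav', data_dir='/data')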