Commit 38bda00

[ASR] Added a script for evaluating audio-to-audio metrics for a manifest file (audio_to_audio_eval.py)

Signed-off-by: Ante Jukić <[email protected]>
anteju committed Feb 21, 2023
1 parent 0e70135 commit 38bda00
Showing 5 changed files with 302 additions and 19 deletions.
278 changes: 278 additions & 0 deletions examples/asr/audio_to_audio/audio_to_audio_eval.py
@@ -0,0 +1,278 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Script to compute audio-to-audio evaluation metrics for a given model on a dataset described by a manifest file.
The manifest file must include the path to the input audio and the path to the target (ground-truth) audio.

Note: This script depends on the `process_audio.py` script, and therefore both scripts should be
located in the same directory during execution.

# Arguments
<< All arguments of `process_audio.py` are inherited by this script, so please refer to `process_audio.py`
for the full list of arguments >>

dataset_manifest: Required - path to the dataset JSON manifest file (in NeMo format)
output_dir: Optional - output directory where the processed audio will be saved
metrics: Optional - list of metrics to evaluate. Defaults to [sdr,estoi]
sample_rate: Optional - sample rate for the loaded audio. Defaults to 16 kHz
only_score_manifest: Optional - if set, processing is skipped and the processed audio is assumed to be already
    available in the dataset_manifest
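
# Manifest format
Each line of the dataset manifest is a JSON entry. A minimal entry may look like the following
(the input key shown here is illustrative and must match the input key expected by `process_audio.py`;
`target_audio_filepath` is the default target key):

    {"input_filepath": "audio/utt01_input.wav", "target_audio_filepath": "audio/utt01_target.wav", "duration": 3.14}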
# Usage
## To score a dataset with a manifest file that contains the input audio which needs to be processed and target audio
python audio_to_audio_eval.py \
    model_path=null \
    pretrained_model=null \
    dataset_manifest=<Mandatory: path to a dataset manifest file> \
    output_dir=<Optional: directory where processed audio will be saved> \
    processed_channel_selector=<Optional: list of channels to select from the processed audio file> \
    target_key=<Optional: key for the target audio in the dataset manifest. Default: target_audio_filepath> \
    target_channel_selector=<Optional: list of channels to select from the target audio file> \
    metrics=<Optional: list of metrics to evaluate. Defaults to [sdr,estoi]> \
    batch_size=32 \
    amp=True

## To score a manifest file that has already been processed and contains both processed audio and target audio
python audio_to_audio_eval.py \
    dataset_manifest=<Mandatory: path to a dataset manifest file> \
    processed_key=<Optional: key for the processed audio in the dataset manifest. Default: processed_audio_filepath> \
    processed_channel_selector=<Optional: list of channels to select from the processed audio file> \
    target_key=<Optional: key for the target audio in the dataset manifest. Default: target_audio_filepath> \
    target_channel_selector=<Optional: list of channels to select from the target audio file> \
    metrics=<Optional: list of metrics to evaluate. Defaults to [sdr,estoi]> \
    only_score_manifest=true \
    batch_size=32 \
    amp=True
"""
import json
import os
import tempfile
from dataclasses import dataclass, field, is_dataclass
from typing import List, Optional

import process_audio
import torch
from omegaconf import OmegaConf, open_dict
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
from tqdm import tqdm

from nemo.collections.asr.data import audio_to_audio_dataset
from nemo.collections.asr.metrics.audio import AudioMetricWrapper
from nemo.collections.common.parts.preprocessing import manifest
from nemo.core.config import hydra_runner
from nemo.utils import logging


@dataclass
class AudioEvaluationConfig(process_audio.ProcessConfig):
    # Processed audio config
    processed_channel_selector: Optional[List] = None
    processed_key: str = 'processed_audio_filepath'

    # Target audio configs
    target_dataset_dir: Optional[str] = None  # If not provided, defaults to dirname(cfg.dataset_manifest)
    target_channel_selector: Optional[List] = None
    target_key: str = 'target_audio_filepath'

    # Sample rate for audio evaluation
    sample_rate: int = 16000

    # Score an existing manifest without running processing
    only_score_manifest: bool = False

    # Metrics to calculate
    metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi'])

def get_evaluation_dataloader(config):
    """Prepare a dataloader for evaluation.

    The config dict is expected to provide `manifest_filepath`, `sample_rate`, `input_key`,
    `input_channel_selector`, `target_key`, `target_channel_selector`, `batch_size`, and
    optionally `num_workers` and `drop_last` (see the config assembled in `main` below).
    """
    dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config['batch_size'],
        collate_fn=dataset.collate_fn,
        drop_last=config.get('drop_last', False),
        shuffle=False,
        num_workers=config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)),
        pin_memory=True,
    )


def get_metrics(cfg: AudioEvaluationConfig):
    """Prepare a dictionary with metrics."""
    available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq']

    metrics = dict()
    for name in sorted(set(cfg.metrics)):
        name = name.lower()
        if name == 'sdr':
            metric = AudioMetricWrapper(metric=SignalDistortionRatio())
        elif name == 'sisdr':
            metric = AudioMetricWrapper(metric=ScaleInvariantSignalDistortionRatio())
        elif name == 'stoi':
            metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=False))
        elif name == 'estoi':
            metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=True))
        elif name == 'pesq':
            metric = AudioMetricWrapper(metric=PerceptualEvaluationSpeechQuality(fs=cfg.sample_rate, mode='wb'))
        else:
            raise ValueError(f'Unexpected metric: {name}. Currently available metrics: {available_metrics}')

        metrics[name] = metric

    return metrics
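
# The metrics dict returned above is consumed roughly as follows (illustrative sketch; signal
# tensors have shape (B, C, T), and `main` below shows the actual loop):
#
#   metrics = get_metrics(cfg)
#   for name, metric in metrics.items():
#       metric.update(preds=processed_signal, target=target_signal, input_length=input_length)
#   values = {name: metric.compute().item() for name, metric in metrics.items()}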


@hydra_runner(config_name="AudioEvaluationConfig", schema=AudioEvaluationConfig)
def main(cfg: AudioEvaluationConfig):
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.audio_dir is not None:
        raise RuntimeError(
            "Evaluation script requires ground truth audio to be passed via a manifest file. "
            "If manifest file is available, submit it via `dataset_manifest` argument."
        )

    if not os.path.exists(cfg.dataset_manifest):
        raise FileNotFoundError(f'The dataset manifest file could not be found at path: {cfg.dataset_manifest}')

    if cfg.target_dataset_dir is None:
        # Assume the target data is available in the same directory as the input data
        cfg.target_dataset_dir = os.path.dirname(cfg.dataset_manifest)
    elif not os.path.isdir(cfg.target_dataset_dir):
        raise FileNotFoundError(f'Target dataset dir could not be found at path: {cfg.target_dataset_dir}')

    # Setup metrics
    metrics = get_metrics(cfg)

    # Processing
    if not cfg.only_score_manifest:
        # Process audio using the configured model and save in the output directory
        process_cfg = process_audio.main(cfg)  # type: ProcessConfig

        # Release GPU memory if it was used during processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logging.info('Finished processing audio.')
    else:
        # Score the input manifest, no need to run a model
        cfg.output_filename = cfg.dataset_manifest
        process_cfg = cfg

    # Evaluation
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Prepare a temporary manifest with processed audio and target
        temporary_manifest_filepath = os.path.join(tmp_dir, 'manifest.json')

        num_files = 0

        with open(process_cfg.output_filename, 'r') as f_processed, open(
            temporary_manifest_filepath, 'w', encoding='utf-8'
        ) as f_tmp:
            for line_processed in f_processed:
                data_processed = json.loads(line_processed)

                if cfg.processed_key not in data_processed:
                    raise ValueError(
                        f'Processed key {cfg.processed_key} not found in manifest: {process_cfg.output_filename}.'
                    )

                if cfg.target_key not in data_processed:
                    raise ValueError(
                        f'Target key {cfg.target_key} not found in manifest: {process_cfg.output_filename}.'
                    )

                item = {
                    'processed': manifest.get_full_path(
                        audio_file=data_processed[cfg.processed_key], manifest_file=process_cfg.output_filename
                    ),
                    'target': manifest.get_full_path(
                        audio_file=data_processed[cfg.target_key], data_dir=cfg.target_dataset_dir
                    ),
                    'duration': data_processed.get('duration'),
                }

                # Double-check files exist
                for key in ['processed', 'target']:
                    if not os.path.isfile(item[key]):
                        raise ValueError(f'File for key "{key}" not found at: {item[key]}.\nCurrent item: {item}')

                # Warn if we're comparing the same files
                if item['target'] == item['processed']:
                    logging.warning('Using the same file as processed and target: %s', item['target'])

                # Write the entry in the temporary manifest file
                f_tmp.write(json.dumps(item) + '\n')

                num_files += 1

        # Prepare dataloader
        config = {
            'manifest_filepath': temporary_manifest_filepath,
            'sample_rate': cfg.sample_rate,
            'input_key': 'processed',
            'input_channel_selector': cfg.processed_channel_selector,
            'target_key': 'target',
            'target_channel_selector': cfg.target_channel_selector,
            'batch_size': min(cfg.batch_size, num_files),
            'num_workers': cfg.num_workers,
        }
        temporary_dataloader = get_evaluation_dataloader(config)

        # Calculate metrics
        for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'):
            processed_signal, processed_length, target_signal, target_length = eval_batch

            if not torch.equal(processed_length, target_length):
                raise RuntimeError(
                    f'Length mismatch between processed and target signals: {processed_length} vs {target_length}.'
                )

            for name, metric in metrics.items():
                metric.update(preds=processed_signal, target=target_signal, input_length=target_length)

    # Convert to a dictionary with name: value
    metrics_value = {name: metric.compute().item() for name, metric in metrics.items()}

    logging.info('Finished running evaluation.')

    # Show results
    logging.info('Summary\n')
    logging.info('Data')
    logging.info('\tmanifest: %s', cfg.output_filename)
    logging.info('\ttarget_dataset_dir: %s', cfg.target_dataset_dir)
    logging.info('\tnum_files: %s', num_files)
    logging.info('Metrics')
    for name, value in metrics_value.items():
        logging.info('\t%10s: \t%6.2f', name, value)

    # Inject the metric name and score into the config, and return the entire config
    with open_dict(cfg):
        cfg.metrics_value = metrics_value

    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
9 changes: 3 additions & 6 deletions examples/asr/audio_to_audio/process_audio.py
@@ -198,13 +198,11 @@ def autocast():
 
     # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
     if not cfg.overwrite_output and os.path.exists(cfg.output_dir):
-        logging.warning(
+        raise RuntimeError(
             f"Previous output found at {cfg.output_dir}, and flag `overwrite_output`"
             f"is {cfg.overwrite_output}. Returning without processing."
         )
 
-        return cfg
-
     # Process audio
     with autocast():
         with torch.no_grad():
@@ -225,12 +223,11 @@ def autocast():
             with open(cfg.dataset_manifest, 'r') as fr:
                 for idx, line in enumerate(fr):
                     item = json.loads(line)
-                    item['audio_filepath_unprocessed'] = item[input_key]
-                    item['audio_filepath'] = paths2processed_files[idx]
+                    item['processed_audio_filepath'] = paths2processed_files[idx]
                     f.write(json.dumps(item) + "\n")
         else:
             for idx, processed_file in enumerate(paths2processed_files):
-                item = {'audio_filepath': processed_file}
+                item = {'processed_audio_filepath': processed_file}
                 f.write(json.dumps(item) + "\n")
 
     return cfg
3 changes: 0 additions & 3 deletions nemo/collections/asr/metrics/audio.py
@@ -134,9 +134,6 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, input_length: Option
             target: tensor with target signals, shape (B, C, T)
             input_length: Optional, input tensor with length (in samples) of each signal in the batch, shape (B,).
                 If not provided, it is assumed that all samples are valid.
-
-        Returns:
-            Underlying metric averaged on the current batch.
         """
         preds, target = self._select_channel(preds=preds, target=target)
 
1 change: 0 additions & 1 deletion nemo/collections/asr/models/enhancement_models.py
@@ -129,7 +129,6 @@ def process(
                 'input_channel_selector': input_channel_selector,
                 'batch_size': min(batch_size, len(paths2audio_files)),
                 'num_workers': num_workers,
-                'channel_selector': input_channel_selector,
             }
 
             # Create output dir if necessary
30 changes: 21 additions & 9 deletions nemo/collections/common/parts/preprocessing/manifest.py
@@ -159,7 +159,12 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]:
     return item
 
 
-def get_full_path(audio_file: str, manifest_file: str, audio_file_len_limit: int = 255) -> str:
+def get_full_path(
+    audio_file: str,
+    manifest_file: Optional[str] = None,
+    data_dir: Optional[str] = None,
+    audio_file_len_limit: int = 255,
+) -> str:
     """Get full path to audio_file.
 
     If the audio_file is a relative path and does not exist,
@@ -169,24 +174,31 @@ def get_full_path(audio_file: str, manifest_file: str, audio_file_len_limit: int
     Args:
         audio_file: path to an audio file, either absolute or assumed relative
-            to the manifest directory
+            to the manifest directory or data directory
         manifest_file: path to a manifest file
+        data_dir: path to a directory containing data, use only if a manifest file is not provided
         audio_file_len_limit: limit for length of audio_file when using relative paths
 
     Returns:
         Full path to audio_file.
     """
     audio_file = Path(audio_file)
 
-    if is_datastore_path(manifest_file):
-        # WORKAROUND: pathlib does not support URIs, so use os.path
-        manifest_dir = os.path.dirname(manifest_file)
-    else:
-        manifest_dir = Path(manifest_file).parent.as_posix()
+    if manifest_file is None and data_dir is None:
+        raise ValueError(f'Use either manifest_file or data_dir to specify the data directory.')
+    elif manifest_file is not None and data_dir is not None:
+        raise ValueError(f'Parameters manifest_file and data_dir cannot be used simultaneously.')
+
+    if data_dir is None:
+        if is_datastore_path(manifest_file):
+            # WORKAROUND: pathlib does not support URIs, so use os.path
+            data_dir = os.path.dirname(manifest_file)
+        else:
+            data_dir = Path(manifest_file).parent.as_posix()
 
     if (len(str(audio_file)) < audio_file_len_limit) and not audio_file.is_file() and not audio_file.is_absolute():
-        # assume audio_file path is relative to manifest_dir
-        audio_file_path = os.path.join(manifest_dir, audio_file.as_posix())
+        # assume audio_file path is relative to data_dir
+        audio_file_path = os.path.join(data_dir, audio_file.as_posix())
 
     if is_datastore_path(audio_file_path):
         # If audio was originally on an object store, use locally-cached path
