ASR evaluator #5728

Merged
38 commits merged from asr_evaluator_engine into main on Jan 11, 2023

Commits (38)
306291c
backbone
fayejf Nov 29, 2022
0d30528
engineer and analyzer
fayejf Dec 6, 2022
ff4db57
offline_by_chunked
fayejf Dec 21, 2022
bd3d4fd
test_ds wip
fayejf Dec 21, 2022
b8d0b6e
temp remove inference
fayejf Jan 3, 2023
707caa2
mandarin yaml
fayejf Jan 3, 2023
0fb9fa4
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 3, 2023
94fa630
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2023
7ef4ac2
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 5, 2023
11e163c
augmentor and a few updates
fayejf Jan 6, 2023
f40f6dc
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 6, 2023
fdbfa20
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 6, 2023
3e001ee
address alerts and revert unnecessary changes
fayejf Jan 6, 2023
52122bf
Add readme
fayejf Jan 6, 2023
86092ca
rename
fayejf Jan 6, 2023
bf5cfe8
typo fix
fayejf Jan 6, 2023
431877d
small fix
fayejf Jan 6, 2023
86fe931
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 6, 2023
8af6afe
add missing header
fayejf Jan 6, 2023
b50483b
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 9, 2023
9b1a52a
rename augmentor_config to augmentor
fayejf Jan 9, 2023
810adcf
raname inference_mode to inference
fayejf Jan 9, 2023
503fd04
move utils.py
fayejf Jan 9, 2023
ae50af2
update temp file
fayejf Jan 9, 2023
dcd842e
make wer cer clear
fayejf Jan 9, 2023
4f16213
seed_everything
fayejf Jan 10, 2023
d36309b
fix missing rn augmentor_config in rnnt
fayejf Jan 10, 2023
d435217
fix rnnt transcribe
fayejf Jan 10, 2023
1614acf
add more docstring and style fix
fayejf Jan 10, 2023
814910a
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 10, 2023
6a0ad5a
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 10, 2023
50473e1
address codeQL
fayejf Jan 10, 2023
00c7090
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 10, 2023
840c8de
reflect comments
fayejf Jan 10, 2023
32f48e8
update readme
fayejf Jan 10, 2023
047710b
Merge branch 'main' into asr_evaluator_engine
fayejf Jan 10, 2023
70085ea
clearer
fayejf Jan 10, 2023
7533651
Merge branch 'main' into asr_evaluator_engine
stevehuang52 Jan 11, 2023
@@ -41,6 +41,7 @@
from dataclasses import dataclass, is_dataclass
from typing import Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

@@ -71,6 +72,7 @@ class TranscriptionConfig:
num_workers: int = 0
append_pred: bool = False # Sets the mode of work; if True, it will add a new field with transcriptions.
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than the standard one.
random_seed: Optional[int] = None # Seed to use in seed_everything()

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
@@ -96,6 +98,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

if cfg.random_seed:
pl.seed_everything(cfg.random_seed)

if cfg.model_path is None and cfg.pretrained_name is None:
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
if cfg.audio_dir is None and cfg.dataset_manifest is None:
@@ -62,6 +62,7 @@
from dataclasses import dataclass, is_dataclass
from typing import Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf, open_dict

@@ -95,6 +96,7 @@ class TranscriptionConfig:
num_workers: int = 0
append_pred: bool = False # Sets the mode of work; if True, it will add a new field with transcriptions.
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than the standard one.
random_seed: Optional[int] = None # Seed to use in seed_everything()

# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
@@ -127,6 +129,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

if cfg.random_seed:
pl.seed_everything(cfg.random_seed)

if cfg.model_path is None and cfg.pretrained_name is None:
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
if cfg.audio_dir is None and cfg.dataset_manifest is None:
15 changes: 15 additions & 0 deletions examples/asr/transcribe_speech.py
@@ -108,13 +108,15 @@ class TranscriptionConfig:
dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
channel_selector: Optional[int] = None # Used to select a single channel from multi-channel files
audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest
eval_config_yaml: Optional[str] = None # Path to a YAML file containing the evaluation config

# General configs
output_filename: Optional[str] = None
batch_size: int = 32
num_workers: int = 0
append_pred: bool = False # Sets the mode of work; if True, it will add a new field with transcriptions.
pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than the standard one.
random_seed: Optional[int] = None # Seed to use in seed_everything()

# Set to True to output greedy timestamp information (only for supported models)
compute_timestamps: bool = False
@@ -152,11 +154,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

if cfg.random_seed:
pl.seed_everything(cfg.random_seed)

if cfg.model_path is None and cfg.pretrained_name is None:
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
if cfg.audio_dir is None and cfg.dataset_manifest is None:
raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

# Load augmentor from an external YAML file that contains eval info; could be extended to other features such as VAD, P&C
augmentor = None
if cfg.eval_config_yaml:
eval_config = OmegaConf.load(cfg.eval_config_yaml)
augmentor = eval_config.test_ds.get("augmentor")
logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor} ")

# setup GPU
if cfg.cuda is None:
if torch.cuda.is_available():
@@ -253,6 +265,7 @@ def autocast():
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)
else:
logging.warning(
@@ -264,6 +277,7 @@ def autocast():
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)
else:
transcriptions = asr_model.transcribe(
@@ -272,6 +286,7 @@ def autocast():
num_workers=cfg.num_workers,
return_hypotheses=return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)

logging.info(f"Finished transcribing {len(filepaths)} files !")
3 changes: 3 additions & 0 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -227,6 +227,9 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False),
}

if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer

11 changes: 10 additions & 1 deletion nemo/collections/asr/models/ctc_models.py
@@ -116,8 +116,12 @@ def transcribe(
return_hypotheses: bool = False,
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: Optional[DictConfig] = None,
) -> List[str]:
"""
If you modify this function, please remember to update transcribe_partial_audio() in
nemo/collections/asr/parts/utils/transcribe_utils.py

Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.

Args:
@@ -131,7 +135,7 @@
With hypotheses you can do some postprocessing like getting timestamps or rescoring
num_workers: (int) number of workers for DataLoader
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`.

augmentor: (DictConfig) If provided, audio samples are augmented on the fly during transcription. Defaults to None.
Returns:
A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files
"""
@@ -182,6 +186,9 @@ def transcribe(
'channel_selector': channel_selector,
}

if augmentor:
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
logits, logits_len, greedy_predictions = self.forward(
@@ -724,6 +731,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
'pin_memory': True,
'channel_selector': config.get('channel_selector', None),
}
if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer
3 changes: 3 additions & 0 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -579,5 +579,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False),
}

if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer
9 changes: 8 additions & 1 deletion nemo/collections/asr/models/rnnt_models.py
@@ -217,6 +217,7 @@ def transcribe(
partial_hypothesis: Optional[List['Hypothesis']] = None,
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: Optional[DictConfig] = None,
) -> Tuple[List[str], Optional[List['Hypothesis']]]:
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -232,7 +233,7 @@
With hypotheses you can do some postprocessing like getting timestamps or rescoring
num_workers: (int) number of workers for DataLoader
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.

augmentor: (DictConfig) If provided, audio samples are augmented on the fly during transcription. Defaults to None.
Returns:
A list of transcriptions in the same order as paths2audio_files. Will also return an optional list of hypotheses.
"""
@@ -277,6 +278,9 @@
'channel_selector': channel_selector,
}

if augmentor:
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
encoded, encoded_len = self.forward(
@@ -938,6 +942,9 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
'pin_memory': True,
}

if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer

7 changes: 6 additions & 1 deletion nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -334,13 +334,16 @@ def write_transcription(

def transcribe_partial_audio(
asr_model,
path2manifest: str,
path2manifest: Optional[str] = None,
batch_size: int = 4,
logprobs: bool = False,
return_hypotheses: bool = False,
num_workers: int = 0,
channel_selector: Optional[int] = None,
augmentor: Optional[DictConfig] = None,
) -> List[str]:
"""
See the description of this function in transcribe() in nemo/collections/asr/models/ctc_models.py
"""

assert isinstance(asr_model, EncDecCTCModel), "Currently only CTC models are supported."

@@ -377,6 +380,8 @@ def transcribe_partial_audio(
'num_workers': num_workers,
'channel_selector': channel_selector,
}
if augmentor:
config['augmentor'] = augmentor

temporary_datalayer = asr_model._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
44 changes: 44 additions & 0 deletions tools/asr_evaluator/README.md
@@ -0,0 +1,44 @@
ASR evaluator
--------------------

A tool for thoroughly evaluating the performance of ASR models, as well as related features such as Voice Activity Detection.

Features:
- Evaluate a model in any of the three modes currently supported by NeMo: offline, chunked, and offline_by_chunked.
- On-the-fly data augmentation (silence, noise, etc.) for ASR robustness evaluation.
- Inspect the model's performance through detailed insertion, deletion, and substitution error rates, both per sample and across all samples.
- Evaluate a model's reliability on different target groups, such as gender and audio length, if metadata is present.


The ASR evaluator contains two main parts:
- **ENGINE**, which conducts ASR inference.
- **ANALYST**, which evaluates model performance based on the predictions.

In the Analyst, we can evaluate on metadata (such as duration, emotion, etc.) if it is present in the manifest. For example, with the following config we can calculate WERs for audio in different duration groups, where each group (in seconds) is defined by [[0,2],[2,5],[5,10],[10,20],[20,100000]]. We can also calculate WERs for three groups of emotions, where each group is defined by [['happy','laugh'],['neutral'],['sad']]. Moreover, if we set save_wer_per_class=True, it will calculate WERs for all classes present in the data (i.e. the classes above plus 'cry', which is present in the data but not in the slots).

```
analyst:
metadata:
duration:
enable: True
slot: [[0,2],[2,5],[5,10],[10,20],[20,100000]]
save_wer_per_class: False # whether to save the WER for each class present in the data.

emotion:
enable: True
slot: [['happy','laugh'],['neutral'],['sad']] # the data could contain 'cry', which is not in the slots we focus on.
save_wer_per_class: False
```
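
For instance, a manifest entry carrying such metadata might look like the following (a hypothetical example; the field values are illustrative):

```
{"audio_filepath": "sample_01.wav", "text": "hello world", "duration": 3.2, "emotion": "happy"}
```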


Check `./conf/eval.yaml` for the supported configuration.

If you plan to evaluate or add new tasks such as Punctuation and Capitalization, add them to the engine; a minimal config sketch follows.
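
For reference, an engine section that enables on-the-fly noise augmentation might look like this minimal sketch (the field nesting is an assumption based on the override paths in the run command below; consult `./conf/eval.yaml` for the actual schema):

```
engine:
  pretrained_name: "stt_en_conformer_transducer_large"
  inference:
    mode: "offline"  # one of: offline, chunked, offline_by_chunked
  test_ds:
    augmentor:
      noise:
        manifest_path: /path/to/noise_manifest.json  # illustrative path
```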

Run
```
python asr_evaluator.py \
engine.pretrained_name="stt_en_conformer_transducer_large" \
engine.inference.mode="offline" \
engine.test_ds.augmentor.noise.manifest_path=<manifest file for noise data>
```
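
The evaluator appends one JSON object per run to a report file (`report.json` by default, or `writer.report_filename` if set). Its keys follow `asr_evaluator.py` below: the git hash (when `env.save_git_hash` is enabled), the overall result under `res`, one entry per enabled metadata target, and the flattened engine and metric-calculator configs. A sketch of its shape, with placeholder values:

```
{"git_hash": "<commit hash>", "res": {"wer": 0.083}, "duration": "<per-slot results>", "pretrained_name": "<model name>"}
```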
92 changes: 92 additions & 0 deletions tools/asr_evaluator/asr_evaluator.py
@@ -0,0 +1,92 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import git
from omegaconf import OmegaConf
from utils import cal_target_metadata_wer, cal_write_wer, run_asr_inference

from nemo.core.config import hydra_runner
from nemo.utils import logging


"""
This script serves as an evaluator of ASR models.
Usage:
python asr_evaluator.py \
engine.pretrained_name="stt_en_conformer_transducer_large" \
engine.inference.mode="offline" \
engine.test_ds.augmentor.noise.manifest_path=<manifest file for noise data> \
.....

Check out the parameters in ./conf/eval.yaml
"""


@hydra_runner(config_path="conf", config_name="eval.yaml")
def main(cfg):
report = {}
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

# Store git hash for reproducibility
if cfg.env.save_git_hash:
repo = git.Repo(search_parent_directories=True)
report['git_hash'] = repo.head.object.hexsha

## Engine
# Could skip the next line to use an already generated manifest

# If you need to change more parameters for ASR inference, change them in
# 1) the shell script in eval_utils.py in nemo/collections/asr/parts/utils, or
# 2) the TranscriptionConfig at the top of executed scripts such as transcribe_speech.py in examples/asr
cfg.engine = run_asr_inference(cfg=cfg.engine)

## Analyst
cfg, total_res, eval_metric = cal_write_wer(cfg)
report.update({"res": total_res})

for target in cfg.analyst.metadata:
if cfg.analyst.metadata[target].enable:
occ_avg_wer = cal_target_metadata_wer(
manifest=cfg.analyst.metric_calculator.output_filename,
target=target,
meta_cfg=cfg.analyst.metadata[target],
eval_metric=eval_metric,
)
report[target] = occ_avg_wer

config_engine = OmegaConf.to_object(cfg.engine)
report.update(config_engine)

config_metric_calculator = OmegaConf.to_object(cfg.analyst.metric_calculator)
report.update(config_metric_calculator)

pretty = json.dumps(report, indent=4)
res = "%.3f" % (report["res"][eval_metric] * 100)
logging.info(pretty)
logging.info(f"Overall {eval_metric} is {res} %")

## Writer
report_file = "report.json"
if "report_filename" in cfg.writer and cfg.writer.report_filename:
report_file = cfg.writer.report_filename

with open(report_file, "a") as fout:
json.dump(report, fout)
fout.write('\n')
fout.flush()


if __name__ == "__main__":
main()