Create per.py #7538

Merged: 153 commits, Oct 7, 2023

Commits (153)
9c461ea
Move model precision copy (#7336)
maanug-nv Sep 7, 2023
59802b1
Fix PEFT checkpoint loading (#7388)
blahBlahhhJ Sep 7, 2023
0d97b6c
Use distributed optimizer support for multiple dtypes (#7359)
timmoon10 Sep 7, 2023
201cccc
minor fix for llama ckpt conversion script (#7387)
blahBlahhhJ Sep 7, 2023
869240d
Fix wrong calling of librosa.get_duration() in notebook (#7376)
RobinDong Sep 8, 2023
230146f
[PATCH] PEFT import mcore (#7393)
blahBlahhhJ Sep 8, 2023
c2c7d41
Create per.py
ssh-meister Sep 27, 2023
c7ba7ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
6723bda
[TTS] Added a callback for logging initial data (#7384)
anteju Sep 8, 2023
0fffeef
Update Core Commit (#7402)
aklife97 Sep 9, 2023
3eb95a4
Use cfg attribute in bert (#7394)
maanug-nv Sep 9, 2023
834a5c7
Add support for bias conversion in Swiglu models (#7386)
titu1994 Sep 9, 2023
c871147
Update save_to and restore_from for dist checkpointing (#7343)
ericharper Sep 9, 2023
3eed031
fix forward for with mcore=false (#7403)
JimmyZhang12 Sep 9, 2023
6093209
Fix logging to remove 's/it' from progress bar in Megatron models and…
athitten Sep 9, 2023
592282f
Set Activation Checkpointing Defaults (#7404)
aklife97 Sep 9, 2023
00246c9
make loss mask default to false (#7407)
ericharper Sep 9, 2023
0dac83d
Add dummy userbuffer config files (#7408)
erhoo82 Sep 9, 2023
bb7cd82
add missing ubconf files (#7412)
aklife97 Sep 11, 2023
b92cc7c
New tutorial on Speech Data Explorer (#7405)
Jorjeous Sep 11, 2023
e157cd0
Update ptl training ckpt conversion script to work with dist ckpt (#7…
ericharper Sep 12, 2023
1f28287
Allow disabling sanity checking when num_sanity_val_steps=0 (#7413)
athitten Sep 12, 2023
41e664e
Add comprehensive error messages (#7261)
PeganovAnton Sep 12, 2023
f6fc39a
check NEMO_PATH (#7418)
karpnv Sep 12, 2023
2147faa
layer selection for ia3 (#7417)
arendu Sep 13, 2023
749164a
Fix missing pip package 'einops' (#7397)
RobinDong Sep 14, 2023
952b2a4
Fix failure of pyaudio in Google Colab (#7396)
RobinDong Sep 15, 2023
14ba7f8
Update README.md: output_path --> output_manifest_filepath (#7442)
popcornell Sep 18, 2023
32dc1d0
Add rope dynamic linear scaling (#7437)
hsiehjackson Sep 18, 2023
48c25be
Fix None dataloader issue in PTL2.0 (#7455)
KunalDhawan Sep 19, 2023
bafbdc6
[ASR] Confidence measure -> method renames (#7434)
GNroy Sep 19, 2023
1660781
Add steps for document of getting dataset 'SF Bilingual Speech' (#7378)
RobinDong Sep 19, 2023
cffe476
RNN-T confidence and alignment bugfix (#7381)
GNroy Sep 19, 2023
f634e0c
Fix resume from checkpoint in exp_manager (#7424) (#7426)
github-actions[bot] Sep 19, 2023
569dabc
Fix checking of cuda/cpu device for inputs of Decoder (#7444)
RobinDong Sep 19, 2023
edb95cd
Fix failure of ljspeech's get_data.py (#7430)
RobinDong Sep 19, 2023
1bd4bd0
[TTS] Fix audio codec type checks (#7373)
rlangman Sep 19, 2023
d22b4d1
[TTS] Add dataset to path of logged artifacts (#7462)
rlangman Sep 20, 2023
bd9e53f
Fix sft dataset truncation (#7464)
hsiehjackson Sep 20, 2023
b386f5b
Automatic Lip Reading Recognition (ALR) - ASR/CV (Visual ASR) (#7330)
burchim Sep 20, 2023
0140e23
HF StarCoder to NeMo conversion script (#7421)
janekl Sep 20, 2023
24d2e50
fix bug when loading dist ckpt in peft (#7452)
lhb8125 Sep 21, 2023
6e83a05
Fix adding positional embeddings in-place in transformer module (#7440)
The0nix Sep 21, 2023
63a08df
Fix (#7478)
hsiehjackson Sep 22, 2023
d53b88b
add sleep (#7498) (#7499)
github-actions[bot] Sep 24, 2023
d36dea1
Fix exp manager check for sleep (#7503) (#7504)
github-actions[bot] Sep 25, 2023
660d8e4
bugfix: trainer.accelerator=auto from None. (#7492) (#7493)
github-actions[bot] Sep 25, 2023
0a556b7
[doc] fix broken link (#7481)
stas00 Sep 25, 2023
952b768
[TTS] Read audio as int32 to avoid flac read errors (#7477)
rlangman Sep 26, 2023
43df79d
Add dataset 'AISHELL-3' from OpenSLR for training mandarin TTS (#7409)
RobinDong Sep 26, 2023
6d2dcca
dllogger - log on rank 0 only (#7513)
stas00 Sep 26, 2023
a948ce6
Fix TTS FastPitch tutorial (#7494) (#7516)
github-actions[bot] Sep 26, 2023
0d86cad
Fix get_dist() tensor dimension (#7506) (#7515)
github-actions[bot] Sep 26, 2023
5195b49
bugfix: specify trainer.strategy=auto when devices=1 (#7509) (#7512)
github-actions[bot] Sep 26, 2023
4960b24
fix (#7511)
aklife97 Sep 26, 2023
5a252c6
[TTS] Fix FastPitch data prep tutorial (#7524)
rlangman Sep 27, 2023
89b00c1
add italian tokenization (#7486)
GiacomoLeoneMaria Sep 27, 2023
52a2486
Replace None strategy with auto in tutorial notebooks (#7521) (#7527)
github-actions[bot] Sep 27, 2023
04d143b
unpin setuptools (#7534) (#7535)
github-actions[bot] Sep 27, 2023
2854d39
Update per.py
ssh-meister Sep 28, 2023
2f0d6b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2023
6ff3768
Create punctuation_rates.py
ssh-meister Sep 29, 2023
89504b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
85b87c0
Format fixing
ssh-meister Oct 2, 2023
21a1ed7
added nemo.logging, header, docstrings, how to use
ssh-meister Oct 2, 2023
6a05553
Added asserions to rate_punctuation.py
ssh-meister Oct 2, 2023
8d0abc2
fix typo
ssh-meister Oct 2, 2023
a5bc120
added function for import and call, docstrings
ssh-meister Oct 2, 2023
3627bea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2023
250b031
remove auto generated examples (#7510)
arendu Sep 27, 2023
05c8bf6
Add the `strategy` argument to `MegatronGPTModel.generate()` (#7264)
odelalleau Sep 27, 2023
c847f0b
Fix PTL2.0 related ASR bugs in r1.21.0: Val metrics logging, None dat…
github-actions[bot] Sep 27, 2023
258a159
gpus -> devices (#7542) (#7545)
github-actions[bot] Sep 28, 2023
061f9e9
Update FFMPEG version to fix issue with torchaudio (#7551) (#7553)
github-actions[bot] Sep 28, 2023
1291e2a
PEFT GPT & T5 Refactor (#7308)
meatybobby Sep 28, 2023
9352c7d
fix a typo (#7496)
BestJuly Sep 28, 2023
0f2802e
[TTS] remove curly braces from ${BRANCH} in jupyer notebook cell. (#7…
github-actions[bot] Sep 28, 2023
cf6f95f
add youtube embed url (#7570)
XuesongYang Sep 29, 2023
9b19b68
Remap speakers to continuous range of speaker_id for dataset AISHELL3…
RobinDong Sep 29, 2023
5a4bff0
fix validation_step_outputs initialization for multi-dataloader (#754…
github-actions[bot] Sep 29, 2023
edd3490
Append output of val step to self.validation_step_outputs (#7530) (#7…
github-actions[bot] Sep 29, 2023
bfbe627
[TTS] fixed trainer's accelerator and strategy. (#7569) (#7574)
github-actions[bot] Sep 29, 2023
1ce2455
Append val/test output to instance variable in EncDecSpeakerLabelMode…
github-actions[bot] Sep 29, 2023
c7f4ecb
Fix CustomProgressBar for resume (#7427) (#7522)
github-actions[bot] Sep 30, 2023
9811729
fix typos in nfa and speech enhancement tutorials (#7580) (#7583)
github-actions[bot] Sep 30, 2023
a6217ea
Add strategy as ddp_find_unused_parameters_true for glue_benchmark.py…
github-actions[bot] Sep 30, 2023
aeac6e8
update strategy (#7577) (#7578)
github-actions[bot] Sep 30, 2023
71bd302
Fix typos (#7581)
Kipok Oct 2, 2023
d062641
Change hifigan finetune strategy to ddp_find_unused_parameters_true (…
github-actions[bot] Oct 2, 2023
0352f30
[BugFix] Add missing quotes for auto strategy in tutorial notebooks (…
github-actions[bot] Oct 2, 2023
c76afc0
added per tests
ssh-meister Oct 2, 2023
3f9b7bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2023
8baa297
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2023
6c74c25
[PATCH] PEFT import mcore (#7393)
blahBlahhhJ Sep 8, 2023
f3d58b1
add build os key (#7596) (#7599)
github-actions[bot] Oct 2, 2023
deb80c4
StarCoder SFT test + bump PyT NGC image to 23.09 (#7540)
janekl Oct 2, 2023
620c011
defaults changed (#7600)
arendu Oct 3, 2023
8c892db
add ItalianPhonemesTokenizer (#7587)
GiacomoLeoneMaria Oct 3, 2023
96ec7ef
best ckpt fix (#7564) (#7588)
github-actions[bot] Oct 3, 2023
518b870
rate_punctuation.py
ssh-meister Oct 3, 2023
c8698d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2023
98160b8
Fix tests
ssh-meister Oct 3, 2023
dadf28e
Add files via upload (#7598)
Jorjeous Oct 3, 2023
a4702d9
Fix validation in G2PModel and ThutmoseTaggerModel (#7597) (#7606)
github-actions[bot] Oct 3, 2023
74d9b63
Broadcast loss only when using pipeline parallelism and within the pi…
github-actions[bot] Oct 3, 2023
098c565
Safeguard nemo_text_processing installation on ARM (#7485)
blisc Oct 3, 2023
0bb19db
Function name fixing
ssh-meister Oct 4, 2023
fb13052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
25f9d92
Moving PER to speech_to_text_eval.py
ssh-meister Oct 4, 2023
2c832fd
Update test_metrics.py
ssh-meister Oct 4, 2023
ce75174
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
8600bc5
Added use_per description
ssh-meister Oct 4, 2023
5a2cae1
guard extra dependencies
ssh-meister Oct 4, 2023
efaa9aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
db6d1cb
Write metrics to "output_filename" if "scores_per_sample=True"
ssh-meister Oct 4, 2023
873aa84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
9ffbdd4
scores_per_sample description
ssh-meister Oct 4, 2023
9de6049
Fix import guards
ssh-meister Oct 4, 2023
1c0ceac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
ac73188
Stats printing when HAVE_TABLUATE_AND_PANDAS=False
ssh-meister Oct 4, 2023
1a752e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
33b0142
Bound transformers version in requirements (#7620)
athitten Oct 4, 2023
047d1cc
fix llama2 70b lora tuning bug (#7622)
cuichenx Oct 4, 2023
4552cb5
Fix import error no module name model_utils (#7629)
menon92 Oct 4, 2023
0322a9c
Delete examples/asr/rate_punctuation.py
ssh-meister Oct 4, 2023
3ce067e
Added use_per description
ssh-meister Oct 4, 2023
6bb7d52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
9f46c92
metric and variables name fixing
ssh-meister Oct 5, 2023
3be7fbf
Add else samples = None
ssh-meister Oct 5, 2023
baef202
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2023
501d656
add fc large ls models (#7641)
nithinraok Oct 4, 2023
ee17a52
bugfix: trainer.gpus, trainer.strategy, trainer.accelerator (#7621) (…
github-actions[bot] Oct 5, 2023
d958251
fix ssl models ptl monitor val through logging (#7608) (#7614)
github-actions[bot] Oct 5, 2023
6f010c3
Fix metrics for SE tutorial (#7604) (#7612)
github-actions[bot] Oct 5, 2023
dd1b6c7
Add ddp_find_unused_parameters=True and change accelerator to auto (#…
github-actions[bot] Oct 5, 2023
32181a6
Fix py3.11 dataclasses issue (#7616)
github-actions[bot] Oct 5, 2023
d703b1f
moved per sample metrics computing to transcribe_utils
ssh-meister Oct 5, 2023
735d8c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2023
36139e2
Moved punctuation rates printing to punct_er
ssh-meister Oct 5, 2023
f8bd001
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2023
cca7156
Added reset for DatasetPunctuationErrorRate class
ssh-meister Oct 5, 2023
4df6be3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2023
03589b7
Added compute_metrics_per_sample description
ssh-meister Oct 5, 2023
aa92203
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2023
69b166a
Merge branch 'main' into per
ssh-meister Oct 5, 2023
bb79ded
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2023
9a2018c
Update megatron_gpt_peft_models.py
ssh-meister Oct 5, 2023
fe0961b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2023
e74ef5d
Update speech_to_text_eval.py
ssh-meister Oct 5, 2023
04044ab
Copyright year fixing
ssh-meister Oct 6, 2023
7fb390b
"& AFFILIATES" added
ssh-meister Oct 6, 2023
57f4347
Merge branch 'main' into per
ekmb Oct 6, 2023
fbf3a70
Merge branch 'main' into per
ssh-meister Oct 6, 2023
44 changes: 42 additions & 2 deletions examples/asr/speech_to_text_eval.py
@@ -25,12 +25,18 @@
for full list of arguments >>

dataset_manifest: Required - path to dataset JSON manifest file (in NeMo format)
output_filename: Optional - output filename where the transcriptions will be written.
output_filename: Optional - output filename where the transcriptions will be written. (if scores_per_sample=True,
metrics per sample will be written there too)

use_cer: Bool, whether to compute CER or WER
use_punct_er: Bool, compute dataset Punctuation Error Rate (set the punctuation marks for metrics computation with
"text_processing.punctuation_marks")

tolerance: Float, minimum WER/CER required to pass some arbitrary tolerance.

only_score_manifest: Bool, when set will skip audio transcription and just calculate WER of provided manifest.
scores_per_sample: Bool, compute metrics for each sample separately (if only_score_manifest=True, scores per sample
will be added to the manifest at the dataset_manifest path)

# Usage

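The script's full usage block is collapsed in this diff view; as a minimal sketch, the new options might be combined like this when scoring an existing manifest (paths are placeholders and exact required arguments may differ):

python examples/asr/speech_to_text_eval.py \
    dataset_manifest=/path/to/manifest_with_pred_text.json \
    only_score_manifest=true \
    use_punct_er=true \
    scores_per_sample=true \
    text_processing.punctuation_marks=".,?"
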
@@ -66,7 +72,12 @@
from omegaconf import MISSING, OmegaConf, open_dict

from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig
from nemo.collections.asr.parts.utils.transcribe_utils import (
    PunctuationCapitalization,
    TextProcessingConfig,
    compute_metrics_per_sample,
)
from nemo.collections.common.metrics.punct_er import DatasetPunctuationErrorRate
from nemo.core.config import hydra_runner
from nemo.utils import logging

@@ -82,9 +93,11 @@
    att_context_size: Optional[list] = None

    use_cer: bool = False
    use_punct_er: bool = False
    tolerance: Optional[float] = None

    only_score_manifest: bool = False
    scores_per_sample: bool = False

    text_processing: Optional[TextProcessingConfig] = TextProcessingConfig(
        punctuation_marks=".,?", separate_punctuation=False, do_lowercase=False, rm_punctuation=False,
@@ -154,6 +167,29 @@
f"contain value for `pred_text`."
)

    if cfg.use_punct_er:
        dper_obj = DatasetPunctuationErrorRate(
            hypotheses=predicted_text,
            references=ground_truth_text,
            punctuation_marks=list(cfg.text_processing.punctuation_marks),
        )
        dper_obj.compute()

    if cfg.scores_per_sample:
        metrics_to_compute = ["wer", "cer"]

        if cfg.use_punct_er:
            metrics_to_compute.append("punct_er")

        samples_with_metrics = compute_metrics_per_sample(
            manifest_path=cfg.dataset_manifest,
            reference_field="text",
            hypothesis_field="pred_text",
            metrics=metrics_to_compute,
            punctuation_marks=cfg.text_processing.punctuation_marks,
            output_manifest_path=cfg.output_filename,
        )

Check notice (Code scanning / CodeQL): Unused local variable. Variable samples_with_metrics is not used.

    # Compute the WER
    cer = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=True)
    wer = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=False)
@@ -173,6 +209,10 @@

    logging.info(f'Dataset WER/CER ' + str(round(100 * wer, 2)) + "%/" + str(round(100 * cer, 2)) + "%")

    if cfg.use_punct_er:
        dper_obj.print()
        dper_obj.reset()

    # Inject the metric name and score into the config, and return the entire config
    with open_dict(cfg):
        cfg.metric_name = metric_name
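
Before the implementation of the per-sample helper below, a minimal standalone sketch of the DatasetPunctuationErrorRate API driven by the script above; the reference and hypothesis strings are invented toy examples, and only the constructor arguments and methods visible in this diff are used:

from nemo.collections.common.metrics.punct_er import DatasetPunctuationErrorRate

# Toy data, invented for illustration only.
references = ["hello, world.", "how are you?"]
hypotheses = ["hello world.", "how are you."]

dper_obj = DatasetPunctuationErrorRate(
    hypotheses=hypotheses,
    references=references,
    punctuation_marks=[".", ",", "?"],
)
dper_obj.compute()  # accumulate dataset-level punctuation statistics
dper_obj.print()    # log the punctuation rates, as speech_to_text_eval.py does after scoring
dper_obj.reset()    # clear accumulated state, mirroring the script's cleanup
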
92 changes: 92 additions & 0 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -23,9 +23,11 @@
from tqdm.auto import tqdm

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.common.metrics.punct_er import OccurancePunctuationErrorRate
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
from nemo.utils import logging, model_utils

@@ -472,6 +474,96 @@ def transcribe_partial_audio(
    return hypotheses


def compute_metrics_per_sample(
    manifest_path: str,
    reference_field: str = "text",
    hypothesis_field: str = "pred_text",
    metrics: list[str] = ["wer"],
    punctuation_marks: list[str] = [".", ",", "?"],
    output_manifest_path: str = None,
) -> dict:

    '''
    Computes metrics per sample for given manifest

    Args:
        manifest_path: str, Required - path to dataset JSON manifest file (in NeMo format)
        reference_field: str, Optional - name of field in .json manifest with the reference text ("text" by default).
        hypothesis_field: str, Optional - name of field in .json manifest with the hypothesis text ("pred_text" by default).
        metrics: list[str], Optional - list of metrics to be computed (currently supported "wer", "cer", "punct_er")
        punctuation_marks: list[str], Optional - list of punctuation marks for computing punctuation error rate ([".", ",", "?"] by default).
        output_manifest_path: str, Optional - path where .json manifest with calculated metrics will be saved.

    Returns:
        samples: dict - Dict of samples with calculated metrics
    '''

    supported_metrics = ["wer", "cer", "punct_er"]

    if len(metrics) == 0:
        raise AssertionError(
            f"'metrics' list is empty. \
            Select the metrics from the supported: {supported_metrics}."
        )

    for metric in metrics:
        if metric not in supported_metrics:
            raise AssertionError(
                f"'{metric}' metric is not supported. \
                Currently supported metrics are {supported_metrics}."
            )

    if "punct_er" in metrics:
        if len(punctuation_marks) == 0:
            raise AssertionError("punctuation_marks list can't be empty when 'punct_er' metric is enabled.")
        else:
            oper_obj = OccurancePunctuationErrorRate(punctuation_marks=punctuation_marks)

    use_wer = "wer" in metrics
    use_cer = "cer" in metrics
    use_punct_er = "punct_er" in metrics

    with open(manifest_path, 'r') as manifest:
        lines = manifest.readlines()
        samples = [json.loads(line) for line in lines]
        samples_with_metrics = []

        logging.info(f"Computing {', '.join(metrics)} per sample")

        for sample in tqdm(samples):
            reference = sample[reference_field]
            hypothesis = sample[hypothesis_field]

            if use_wer:
                sample_wer = word_error_rate(hypotheses=[hypothesis], references=[reference], use_cer=False)
                sample["wer"] = round(100 * sample_wer, 2)

            if use_cer:
                sample_cer = word_error_rate(hypotheses=[hypothesis], references=[reference], use_cer=True)
                sample["cer"] = round(100 * sample_cer, 2)

            if use_punct_er:
                operation_amounts, substitution_amounts, punctuation_rates = oper_obj.compute(
                    reference=reference, hypothesis=hypothesis
                )
                sample["punct_correct_rate"] = round(100 * punctuation_rates.correct_rate, 2)
                sample["punct_deletions_rate"] = round(100 * punctuation_rates.deletions_rate, 2)
                sample["punct_insertions_rate"] = round(100 * punctuation_rates.insertions_rate, 2)
                sample["punct_substitutions_rate"] = round(100 * punctuation_rates.substitutions_rate, 2)
                sample["punct_error_rate"] = round(100 * punctuation_rates.punct_er, 2)

            samples_with_metrics.append(sample)

    if output_manifest_path is not None:
        with open(output_manifest_path, 'w') as output:
            for sample in samples_with_metrics:
                line = json.dumps(sample)
                output.writelines(f'{line}\n')
        logging.info(f'Output manifest saved: {output_manifest_path}')

    return samples_with_metrics
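
A minimal usage sketch for the helper above; the manifest and output paths are placeholders, and each manifest line is assumed to already contain both "text" and "pred_text" fields:

from nemo.collections.asr.parts.utils.transcribe_utils import compute_metrics_per_sample

samples = compute_metrics_per_sample(
    manifest_path="/path/to/manifest_with_pred_text.json",  # placeholder path
    reference_field="text",
    hypothesis_field="pred_text",
    metrics=["wer", "cer", "punct_er"],
    punctuation_marks=[".", ",", "?"],
    output_manifest_path="/path/to/scored_manifest.json",  # optional; omit to skip writing a manifest
)

# Each returned sample dict now carries the per-sample fields added above,
# e.g. sample["wer"], sample["cer"], sample["punct_error_rate"].
for sample in samples:
    print(sample["wer"], sample.get("punct_error_rate"))
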


class PunctuationCapitalization:
    def __init__(self, punctuation_marks: str):
        """