From fcb7aa95c1f5015d3c092249e2f7c11dc8da978d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 27 Nov 2023 14:54:37 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tools/speech_data_explorer/data_explorer.py  |  53 ++---
 .../sde/dataloader/dataset.py                | 198 +++++++++++-------
 .../sde/dataloader/engines/cudf_engine.py    |  31 +--
 .../sde/dataloader/sample.py                 | 117 ++++++-----
 4 files changed, 233 insertions(+), 166 deletions(-)

diff --git a/tools/speech_data_explorer/data_explorer.py b/tools/speech_data_explorer/data_explorer.py
index 264c191217a7..b3193a03b535 100755
--- a/tools/speech_data_explorer/data_explorer.py
+++ b/tools/speech_data_explorer/data_explorer.py
@@ -121,10 +121,7 @@ def parse_args():
         help='field name for which you want to see statistics (optional). Example: pred_text_contextnet.',
     )
     parser.add_argument(
-        '--gpu',
-        '-gpu',
-        action='store_true',
-        help='use GPU-acceleration',
+        '--gpu', '-gpu', action='store_true', help='use GPU-acceleration',
     )
 
     args = parser.parse_args()
@@ -490,7 +487,7 @@ def plot_histogram(data, key, label, gpu_acceleration=False):
         data_frame = data[key].to_list()
     else:
         data_frame = [item[key] for item in data]
-    
+
     fig = px.histogram(
         data_frame=data_frame,
         nbins=50,
@@ -504,10 +501,10 @@ def plot_histogram(data, key, label, gpu_acceleration=False):
     return fig
 
 
-def plot_word_accuracy(vocabulary_data): 
+def plot_word_accuracy(vocabulary_data):
     labels = ['Unrecognized', 'Sometimes recognized', 'Always recognized']
     counts = [0, 0, 0]
-    
+
     if args.gpu:
         counts[0] = (vocabulary_data['Accuracy'] == 0).sum()
         counts[1] = (vocabulary_data['Accuracy'] < 100).sum()
@@ -576,24 +573,27 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
 if args.gpu:
     if args.names_compared is not None:
         raise Exception(f"Currently comparison mode is not available with GPU acceleration.")
-    
+
     hypothesis_fields = ["pred_text"]
     if args.show_statistics is not None:
         hypothesis_fields = [args.show_statistics]
-    
+
     enable_plk = True
     if args.disable_caching_metrics:
         enable_plk = False
-    
+
     cu_df = cuDF()
-    dataset = Dataset(manifest_filepath = args.manifest, data_engine = cu_df,
-                      hypothesis_fields = hypothesis_fields,
-                      estimate_audio_metrics = args.estimate_audio_metrics,
-                      enable_plk = enable_plk)
-    
+    dataset = Dataset(
+        manifest_filepath=args.manifest,
+        data_engine=cu_df,
+        hypothesis_fields=hypothesis_fields,
+        estimate_audio_metrics=args.estimate_audio_metrics,
+        enable_plk=enable_plk,
+    )
+
     dataset = dataset.process()
-    
+
     data = dataset.samples_data
     num_hours = dataset.duration
     vocabulary = dataset.vocabulary_data
@@ -602,8 +602,8 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
     metrics_available = len(dataset.hypotheses) != 0
     if metrics_available:
         wer = dataset.hypotheses[hypothesis_fields[0]].wer
-        cer = dataset.hypotheses[hypothesis_fields[0]].cer 
-        wmr = dataset.hypotheses[hypothesis_fields[0]].wmr 
+        cer = dataset.hypotheses[hypothesis_fields[0]].cer
+        wmr = dataset.hypotheses[hypothesis_fields[0]].wmr
         mwa = dataset.hypotheses[hypothesis_fields[0]].mwa
 else:
     if not comparison_mode:
@@ -706,7 +706,7 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
         figure_word_acc = plot_word_accuracy(vocabulary_data)
 else:
     figure_word_acc = plot_word_accuracy(vocabulary)
-    
+
 stats_layout = [
     dbc.Row(dbc.Col(html.H5(children='Global Statistics'), class_name='text-secondary'), class_name='mt-3'),
     dbc.Row(
@@ -827,7 +827,7 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
 
 wordstable_columns = [{'name': 'Word', 'id': 'Word'}, {'name': 'Count', 'id': 'Amount'}]
 
-if args.gpu: 
+if args.gpu:
     vocabulary_columns = vocabulary.columns
 else:
     vocabulary_columns = vocabulary[0].keys()
@@ -910,7 +910,7 @@ def update_wordstable(page_current, sort_by, filter_query):
             if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
                 if args.gpu:
                     vocabulary_view = vocabulary_view.loc[getattr(operator, op)(vocabulary_view[col_name], filter_value)]
-                else: 
+                else:
                     vocabulary_view = [x for x in vocabulary_view if getattr(operator, op)(x[col_name], filter_value)]
             elif op == 'contains':
                 vocabulary_view = [x for x in vocabulary_view if filter_value in str(x[col_name])]
@@ -918,14 +918,14 @@
     if len(sort_by):
         col = sort_by[0]['column_id']
         ascending = sort_by[0]['direction'] != 'desc'
-        
+
         if args.gpu:
            vocabulary_view = vocabulary_view.sort_values(col, ascending=ascending)
        else:
            vocabulary_view = sorted(vocabulary_view, key=lambda x: x[col], reverse=not ascending)
    if page_current * DATA_PAGE_SIZE >= len(vocabulary_view):
        page_current = len(vocabulary_view) // DATA_PAGE_SIZE
-    
+
    if args.gpu:
        return [
            vocabulary_view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE].to_dict('records'),
@@ -937,6 +937,7 @@
            math.ceil(len(vocabulary_view) / DATA_PAGE_SIZE),
        ]
 
+
 if args.gpu:
     col_names = data.columns
 else:
@@ -1564,7 +1565,7 @@ def update_datatable(page_current, sort_by, filter_query):
            if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
                if args.gpu:
                    data_view = data_view.loc[getattr(operator, op)(data_view[col_name], filter_value)]
-                else: 
+                else:
                    data_view = [x for x in data_view if getattr(operator, op)(x[col_name], filter_value)]
            elif op == 'contains':
                data_view = [x for x in data_view if filter_value in str(x[col_name])]
@@ -1572,14 +1573,14 @@ def update_datatable(page_current, sort_by, filter_query):
    if len(sort_by):
        col = sort_by[0]['column_id']
        ascending = sort_by[0]['direction'] != 'desc'
-        
+
        if args.gpu:
            data_view = data_view.sort_values(col, ascending=ascending)
        else:
            data_view = sorted(data_view, key=lambda x: x[col], reverse=not ascending)
    if page_current * DATA_PAGE_SIZE >= len(data_view):
        page_current = len(data_view) // DATA_PAGE_SIZE
-    
+
    if args.gpu:
        return [
            data_view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE].to_dict('records'),
diff --git a/tools/speech_data_explorer/sde/dataloader/dataset.py b/tools/speech_data_explorer/sde/dataloader/dataset.py
index 085e0a9fa519..479fe23db15e 100644
--- a/tools/speech_data_explorer/sde/dataloader/dataset.py
+++ b/tools/speech_data_explorer/sde/dataloader/dataset.py
@@ -1,28 +1,40 @@
 import json
-from tqdm import tqdm
 import multiprocessing as mp
-import pickle
 import os
+import pickle
 from datetime import datetime
+
+from tqdm import tqdm
+
 from nemo.utils import logging
-from .sample import Sample
+
 from .engines.cudf_engine import cuDF
+from .sample import Sample
+
 
 class Dataset:
-    def __init__(self, manifest_filepath: str, chunksize: int = 10000, data_engine: object = None, n_jobs: int = -1,
-                 reference_field = "text", hypothesis_fields: list[str] = ["pred_text"], hypothesis_labels: list[str] = None,
-                 estimate_audio_metrics: bool = False,
-                 enable_plk: bool = True, plk_filepath: str = None):
+    def __init__(
+        self,
+        manifest_filepath: str,
+        chunksize: int = 10000,
+        data_engine: object = None,
+        n_jobs: int = -1,
+        reference_field="text",
+        hypothesis_fields: list[str] = ["pred_text"],
+        hypothesis_labels: list[str] = None,
+        estimate_audio_metrics: bool = False,
+        enable_plk: bool = True,
+        plk_filepath: str = None,
+    ):
         self.manifest_filepath = manifest_filepath
         self.chunksize = chunksize
         self.data_engine = data_engine
         self.n_jobs = n_jobs
-        
+
         max_jobs = mp.cpu_count()
         if self.n_jobs == -1 or n_jobs > max_jobs:
             self.n_jobs = max_jobs
-        
+
         self.reference_field = reference_field
         self.hypothesis_fields = hypothesis_fields
         self.hypothesis_labels = hypothesis_labels
@@ -31,16 +43,16 @@ def __init__(self, manifest_filepath: str, chunksize: int = 10000, data_engine:
         self.enable_plk = enable_plk
         self.plk_filepath = plk_filepath
         self.chunks = []
-        
+
         self.num_words = 0
         self.num_chars = 0
         self.duration = 0
         self.charset = set()
         self.words_frequencies = dict()
-        
+
         self.samples_data = []
         self.vocabulary_data = []
-    
+
     def _check_hypotheses(self, manifest_line: str):
         if self.hypothesis_fields is not None:
             if self.hypothesis_labels is None:
@@ -48,13 +60,15 @@ def _check_hypotheses(self, manifest_line: str):
                     self.hypothesis_labels = [""]
                 else:
                     self.hypothesis_labels = list(range(1, len(self.hypothesis_fields) + 1))
-            
+
             if len(self.hypothesis_labels) != len(self.hypothesis_fields):
-                logging.error(f"Amount of hypothesis_labels ({len(self.hypothesis_labels)}) is not equal to amount of hypothesis_fields ({len(self.hypothesis_fields)}).")
+                logging.error(
+                    f"Amount of hypothesis_labels ({len(self.hypothesis_labels)}) is not equal to amount of hypothesis_fields ({len(self.hypothesis_fields)})."
+                )
                 raise
             else:
                 sample_to_check = json.loads(manifest_line)
-                
+
                 i = 0
                 while i < len(self.hypothesis_fields):
                     hypothesis_field = self.hypothesis_fields[i]
@@ -63,103 +77,121 @@ def _check_hypotheses(self, manifest_line: str):
                         self.hypothesis_fields.pop(i)
                         self.hypothesis_labels.pop(i)
                     else:
-                        logging.info(f"Field '{hypothesis_field}' was found (labeled as '{self.hypothesis_labels[i]}').")
-                        self.hypotheses[hypothesis_field] = HypothesisMetrics(hypothesis_label = self.hypothesis_labels[i])
+                        logging.info(
+                            f"Field '{hypothesis_field}' was found (labeled as '{self.hypothesis_labels[i]}')."
+                        )
+                        self.hypotheses[hypothesis_field] = HypothesisMetrics(
+                            hypothesis_label=self.hypothesis_labels[i]
+                        )
                     i += 1
-    
-    def _read_manifest(self): 
+
+    def _read_manifest(self):
         logging.info("Reading manifest..")
-        with open(self.manifest_filepath, 'r', encoding = "utf8") as manifest:
+        with open(self.manifest_filepath, 'r', encoding="utf8") as manifest:
             lines = manifest.readlines()
-        
+
         self._check_hypotheses(lines[0])
-        
+
         lines_amount = len(lines)
         logging.info(f"Lines amount: {lines_amount}. Splitting into chunks ({self.chunksize} lines per chunk)..")
-        
+
         start_chunk_indicies = list(range(0, lines_amount, self.chunksize))
         end_chunk_indicies = list(range(self.chunksize, lines_amount, self.chunksize)) + [lines_amount]
-        
-        for start_idx, end_idx in tqdm(zip(start_chunk_indicies, end_chunk_indicies), total = len(start_chunk_indicies)):
-            chunk = DataChunk(manifest_lines = lines[start_idx : end_idx], data_engine = self.data_engine, reference_field = self.reference_field,
-                              hypothesis_fields = self.hypothesis_fields, hypothesis_labels = self.hypothesis_labels,
-                              estimate_audio_metrics = self.estimate_audio_metrics)
+
+        for start_idx, end_idx in tqdm(zip(start_chunk_indicies, end_chunk_indicies), total=len(start_chunk_indicies)):
+            chunk = DataChunk(
+                manifest_lines=lines[start_idx:end_idx],
+                data_engine=self.data_engine,
+                reference_field=self.reference_field,
+                hypothesis_fields=self.hypothesis_fields,
+                hypothesis_labels=self.hypothesis_labels,
+                estimate_audio_metrics=self.estimate_audio_metrics,
+            )
             self.chunks.append(chunk)
-    
+
     def _get_plk_filepath(self):
         timestamp = datetime.fromtimestamp(os.path.getmtime(self.manifest_filepath)).strftime('%Y-%m-%d_%H-%M-%S')
         return f"{self.manifest_filepath.replace('.json', '')}_{timestamp}.pkl"
-    
+
     def _read_pickle(self):
         with open(self.plk_filepath, 'rb') as pkl:
             return pickle.load(pkl)
-    
+
     def _write_pickle(self):
         logging.info(f'Saving .plk file..')
         with open(self.plk_filepath, 'wb') as pkl:
             pickle.dump(self, pkl, pickle.HIGHEST_PROTOCOL)
-        
+
         logging.info(f'{self.plk_filepath} saved.')
-    
+
     def process(self):
         if self.enable_plk:
             logging.info(f'Looking for .plk file ({self.plk_filepath})')
             if self.plk_filepath is None:
                 self.plk_filepath = self._get_plk_filepath()
-            
+
             if os.path.exists(self.plk_filepath):
                 logging.info(f'{self.plk_filepath} found.')
                 return self._read_pickle()
             else:
                 logging.info(f'{self.plk_filepath} not found. Loading data from manifest..')
-        
+
         self._read_manifest()
-        
+
         processed_chunks = []
         logging.info(f'Samples processing ({self.n_jobs} processes)..')
         with mp.Pool(self.n_jobs) as pool:
             for processed_chunk in tqdm(pool.imap(DataChunk.process, self.chunks), total=len(self.chunks)):
                 processed_chunks.append(processed_chunk)
-        
+
         self.chunks = processed_chunks
-        
+
         logging.info(f'Global metrics computing..')
         for chunk in tqdm(self.chunks):
             self.num_words += chunk.num_words
             self.num_chars += chunk.num_chars
             self.duration += chunk.duration
             self.charset.update(chunk.charset)
-            
+
             for hypothesis_field in chunk.hypotheses:
                 self.hypotheses[hypothesis_field].update(chunk.hypotheses[hypothesis_field])
-            
+
             for word in chunk.words_frequencies:
                 self.words_frequencies[word] = self.words_frequencies.get(word, 0) + chunk.words_frequencies[word]
-            
+
             if self.data_engine is not None:
                 self.samples_data.append(chunk.samples_data)
-        
+
         for hypothesis_field in self.hypotheses:
-            self.hypotheses[hypothesis_field].compute(dataset_num_words = self.num_words, dataset_num_chars = self.num_chars)
-        
-        self.duration = round(self.duration / 3600 , 2)
-        
+            self.hypotheses[hypothesis_field].compute(
+                dataset_num_words=self.num_words, dataset_num_chars=self.num_chars
+            )
+
+        self.duration = round(self.duration / 3600, 2)
+
         if self.data_engine is not None:
             logging.info(f'Samples datatable loading..')
             self.samples_data = self.data_engine.concat_samples_chunks(self.samples_data)
-            self.vocabulary_data = self.data_engine.process_vocabulary(words_frequencies = self.words_frequencies,
-                                                                       hypotheses_metrics = self.hypotheses.values())
-        
+            self.vocabulary_data = self.data_engine.process_vocabulary(
+                words_frequencies=self.words_frequencies, hypotheses_metrics=self.hypotheses.values()
+            )
+
         if self.enable_plk:
-            self._write_pickle() 
-        
+            self._write_pickle()
+
         return self
-    
+

 class DataChunk:
-    def __init__(self, manifest_lines: list[str], data_engine: object = None,
-                 reference_field: str = "text", hypothesis_fields: list[str] = ["pred_text"], hypothesis_labels: list[str] = None,
-                 estimate_audio_metrics: bool = False):
+    def __init__(
+        self,
+        manifest_lines: list[str],
+        data_engine: object = None,
+        reference_field: str = "text",
+        hypothesis_fields: list[str] = ["pred_text"],
+        hypothesis_labels: list[str] = None,
+        estimate_audio_metrics: bool = False,
+    ):
         self.manifest_lines = manifest_lines
         self.reference_field = reference_field
         self.estimate_audio_metrics = estimate_audio_metrics
@@ -168,70 +200,76 @@ def __init__(self, manifest_lines: list[str], data_engine: object = None,
         self.num_chars = 0
         self.duration = 0
         self.charset = set()
-        self.words_frequencies = dict() 
-        
+        self.words_frequencies = dict()
+
         self.hypothesis_fields = hypothesis_fields
         self.hypothesis_labels = hypothesis_labels
         self.hypotheses = dict()
-        
+
         for field, label in zip(hypothesis_fields, hypothesis_labels):
-            self.hypotheses[field] = HypothesisMetrics(hypothesis_label = label)
-        
-        self.data_engine = data_engine 
+            self.hypotheses[field] = HypothesisMetrics(hypothesis_label=label)
+
+        self.data_engine = data_engine
         self.samples_data = None
-    
+
     def process(self):
         sample = Sample()
         for manifest_line in self.manifest_lines:
-            sample.parse_line(manifest_line, reference_field = self.reference_field,
-                              hypothesis_fields = self.hypothesis_fields,
-                              hypothesis_labels = self.hypothesis_labels)
-            sample.compute(estimate_audio_metrics = self.estimate_audio_metrics)
-            
+            sample.parse_line(
+                manifest_line,
+                reference_field=self.reference_field,
+                hypothesis_fields=self.hypothesis_fields,
+                hypothesis_labels=self.hypothesis_labels,
+            )
+            sample.compute(estimate_audio_metrics=self.estimate_audio_metrics)
+
             self.samples_dicts.append(sample.sample_dict)
             self.num_words += sample.num_words
             self.num_chars += sample.num_chars
             self.duration += sample.duration
             self.charset.update(sample.charset)
-            
+
             for word in sample.words_frequencies:
                 self.words_frequencies[word] = self.words_frequencies.get(word, 0) + sample.words_frequencies[word]
-            
+
             for hypothesis_field in sample.hypotheses:
                 self.hypotheses[hypothesis_field].update(sample.hypotheses[hypothesis_field])
-            
+
             sample.reset()
-        
+
         if self.data_engine is not None:
             self.samples_data = self.data_engine.load_samples_chunk(self.samples_dicts)
             self.samples_dicts = {}
-        
+
         return self
-    
+
+
 class HypothesisMetrics:
     def __init__(self, hypothesis_label: str = None):
         self.hypothesis_label = hypothesis_label
         self.word_distance = 0
         self.word_match = 0
         self.char_distance = 0
-        
+
         self.wer = None
         self.wmr = None
         self.cer = None
         self.mwa = None
-        
+
         self.match_words_frequencies = dict()
 
-    def update(self, hypothesis: object): 
+    def update(self, hypothesis: object):
         assert self.hypothesis_label == hypothesis.hypothesis_label, "Hypothesis label mismatch!"
-        
+
         self.word_distance += hypothesis.word_distance
         self.word_match += hypothesis.word_match
         self.char_distance += hypothesis.char_distance
-        
+
         for word in hypothesis.match_words_frequencies:
-            self.match_words_frequencies[word] = self.match_words_frequencies.get(word, 0) + hypothesis.match_words_frequencies[word]
-        
+            self.match_words_frequencies[word] = (
+                self.match_words_frequencies.get(word, 0) + hypothesis.match_words_frequencies[word]
+            )
+
     def compute(self, dataset_num_words: int, dataset_num_chars: int):
         self.wer = round(self.word_distance / dataset_num_words * 100.0, 2)
         self.wmr = round(self.word_match / dataset_num_words * 100.0, 2)
diff --git a/tools/speech_data_explorer/sde/dataloader/engines/cudf_engine.py b/tools/speech_data_explorer/sde/dataloader/engines/cudf_engine.py
index f07fb15955ba..2e1d5a041e0e 100644
--- a/tools/speech_data_explorer/sde/dataloader/engines/cudf_engine.py
+++ b/tools/speech_data_explorer/sde/dataloader/engines/cudf_engine.py
@@ -1,4 +1,5 @@
 import cudf.pandas
+
 cudf.pandas.install()
 
 import pandas as pd
@@ -6,38 +7,42 @@
 class cuDF:
     def __init__(self):
         pass
-    
+
     def load_samples_chunk(self, samples: list[dict]):
         chunk = pd.DataFrame(samples)
         return chunk
-    
+
     def concat_samples_chunks(self, samples_chunks: list):
         samples_datatable = pd.concat(samples_chunks).reset_index(drop=True)
         return samples_datatable
-    
+
     def process_vocabulary(self, words_frequencies: dict, hypotheses_metrics: list[object]):
         vocabulary_dfs = []
-        
+
         words_frequencies_df = pd.DataFrame(words_frequencies.items(), columns=["Word", "Amount"]).set_index("Word")
         vocabulary_dfs.append(words_frequencies_df)
-        
+
         for hypothesis_metrics_obj in hypotheses_metrics:
             label = hypothesis_metrics_obj.hypothesis_label
             match_words_frequencies = hypothesis_metrics_obj.match_words_frequencies
-            match_words_frequencies_df = pd.DataFrame(match_words_frequencies.items(), columns=["Word", f"Match_{hypothesis_metrics_obj.hypothesis_label}"]).set_index("Word")
+            match_words_frequencies_df = pd.DataFrame(
+                match_words_frequencies.items(), columns=["Word", f"Match_{hypothesis_metrics_obj.hypothesis_label}"]
+            ).set_index("Word")
             vocabulary_dfs.append(match_words_frequencies_df)
-        
-        vocabulary_datatable = pd.concat(vocabulary_dfs, axis = 1, join = "outer").reset_index().fillna(0)
-        
+        vocabulary_datatable = pd.concat(vocabulary_dfs, axis=1, join="outer").reset_index().fillna(0)
+
         for hypothesis_metrics_obj in hypotheses_metrics:
             label = hypothesis_metrics_obj.hypothesis_label
             postfix = ""
             if label != "":
                 postfix = f"_{label}"
-            
-            vocabulary_datatable[f"Accuracy{postfix}"] = vocabulary_datatable[f"Match_{label}"] / vocabulary_datatable["Amount"] * 100
+
+            vocabulary_datatable[f"Accuracy{postfix}"] = (
+                vocabulary_datatable[f"Match_{label}"] / vocabulary_datatable["Amount"] * 100
+            )
             vocabulary_datatable[f"Accuracy{postfix}"] = vocabulary_datatable[f"Accuracy{postfix}"].round(2)
-            vocabulary_datatable = vocabulary_datatable.drop(f"Match_{label}", axis = 1)
+            vocabulary_datatable = vocabulary_datatable.drop(f"Match_{label}", axis=1)
             hypothesis_metrics_obj.mwa = round(vocabulary_datatable[f"Accuracy{postfix}"].mean(), 2)
-        
+
         return vocabulary_datatable
diff --git a/tools/speech_data_explorer/sde/dataloader/sample.py b/tools/speech_data_explorer/sde/dataloader/sample.py
index 2cebc9e0532e..ab0c22189d7e 100644
--- a/tools/speech_data_explorer/sde/dataloader/sample.py
+++ b/tools/speech_data_explorer/sde/dataloader/sample.py
@@ -1,10 +1,11 @@
 import json
 from collections import Counter
-import jiwer
 from difflib import SequenceMatcher
+
 import editdistance
-import numpy as np
+import jiwer
 import librosa
+import numpy as np
 
 
 class Sample:
@@ -31,20 +32,24 @@ def reset(self):
         self.frequency_bandwidth = None
         self.level_db = None
         self.hypotheses = {}
-    
-    def parse_line(self, manifest_line: str, reference_field: str = "text",
-                   hypothesis_fields: list[str] = ["pred_text"],
-                   hypothesis_labels: list[str] = None):
-        
+
+    def parse_line(
+        self,
+        manifest_line: str,
+        reference_field: str = "text",
+        hypothesis_fields: list[str] = ["pred_text"],
+        hypothesis_labels: list[str] = None,
+    ):
+
         self.sample_dict = json.loads(manifest_line)
         self.reference_text = self.sample_dict.get(reference_field, None)
         self.duration = self.sample_dict.get("duration", None)
-        
+
         if hypothesis_labels is None:
             hypothesis_labels = list(range(1, len(hypothesis_fields) + 1))
-        
+
         for field, label in zip(hypothesis_fields, hypothesis_labels):
-            hypothesis = Hypothesis(hypothesis_text = self.sample_dict[field], hypothesis_label = label)
+            hypothesis = Hypothesis(hypothesis_text=self.sample_dict[field], hypothesis_label=label)
             self.hypotheses[field] = hypothesis
 
     def compute(self, estimate_audio_metrics: bool = False):
@@ -53,24 +58,29 @@ def compute(self, estimate_audio_metrics: bool = False):
         self.num_words = len(self.words)
         self.charset = set(self.reference_text)
         self.words_frequencies = dict(Counter(self.words))
-        
+
         if self.duration is not None:
             self.char_rate = round(self.num_chars / self.duration, 2)
             self.word_rate = round(self.num_words / self.duration, 2)
-        
+
         if len(self.hypotheses) != 0:
             for label in self.hypotheses:
-                self.hypotheses[label].compute(reference_text = self.reference_text, reference_words = self.words,
-                                               reference_num_words = self.num_words, reference_num_chars = self.num_chars)
-        
+                self.hypotheses[label].compute(
+                    reference_text=self.reference_text,
+                    reference_words=self.words,
+                    reference_num_words=self.num_words,
+                    reference_num_chars=self.num_chars,
+                )
+
         if estimate_audio_metrics and self.audio_filepath is not None:
-            
+
             def eval_signal_frequency_bandwidth(signal, sampling_rate, threshold=-50) -> float:
                 time_stride = 0.01
                 hop_length = int(sampling_rate * time_stride)
                 n_fft = 512
                 spectrogram = np.mean(
-                    np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2, axis=1
+                    np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2,
+                    axis=1,
                 )
                 power_spectrum = librosa.power_to_db(S=spectrogram, ref=np.max, top_db=100)
                 frequency_bandwidth = 0
@@ -78,25 +88,27 @@ def eval_signal_frequency_bandwidth(self, signal, sampling_rate, threshold=-50)
                     if power_spectrum[idx] > threshold:
                         frequency_bandwidth = idx / n_fft * sampling_rate
                         break
-            
+
                 return frequency_bandwidth
-            
+
             self.signal, self.sampling_rate = librosa.load(path=self.audio_filepath, sr=None)
-            self.frequency_bandwidth = eval_signal_frequency_bandwidth(signal=self.signal, sampling_rate=self.sampling_rate)
+            self.frequency_bandwidth = eval_signal_frequency_bandwidth(
+                signal=self.signal, sampling_rate=self.sampling_rate
+            )
             self.level_db = 20 * np.log10(np.max(np.abs(self.signal)))
 
         self.add_table_metrics_to_dict()
-    
+
     def add_table_metrics_to_dict(self):
         metrics = {
             "num_chars": self.num_chars,
             "num_words": self.num_words,
         }
-        
+
         if self.duration is not None:
             metrics["char_rate"] = self.char_rate
             metrics["word_rate"] = self.word_rate
-        
+
         if len(self.hypotheses) != 0:
             for label in self.hypotheses:
                 hypothesis_metrics = self.hypotheses[label].get_table_metrics()
@@ -105,7 +117,7 @@ def add_table_metrics_to_dict(self):
         if self.frequency_bandwidth is not None:
             metrics["freq_bandwidth"] = self.frequency_bandwidth
             metrics["level_db"] = self.level_db
-        
+
         self.sample_dict.update(metrics)
 
 
@@ -114,7 +126,7 @@ def __init__(self, hypothesis_text: str, hypothesis_label: str = None):
         self.hypothesis_text = hypothesis_text
         self.hypothesis_label = hypothesis_label
         self.hypothesis_words = None
-        
+
         self.wer = None
         self.wmr = None
         self.num_insertions = None
@@ -123,23 +135,28 @@ def __init__(self, hypothesis_text: str, hypothesis_label: str = None):
         self.word_match = None
         self.word_distance = None
         self.match_words_frequencies = dict()
-        
+
         self.char_distance = None
         self.cer = None
-    
-    def compute(self, reference_text: str, reference_words: list[str] = None,
-                reference_num_words: int = None, reference_num_chars: int = None):
-        
+
+    def compute(
+        self,
+        reference_text: str,
+        reference_words: list[str] = None,
+        reference_num_words: int = None,
+        reference_num_chars: int = None,
+    ):
+
         if reference_words is None:
             reference_words = reference_text.split()
         if reference_num_words is None:
             reference_num_words = len(reference_words)
         if reference_num_chars is None:
             reference_num_chars = len(reference_text)
-        
+
         self.hypothesis_words = self.hypothesis_text.split()
-        
-        #word match metrics
+
+        # word match metrics
         measures = jiwer.compute_measures(reference_text, self.hypothesis_text)
 
         self.wer = round(measures['wer'] * 100.0, 2)
@@ -149,28 +166,34 @@ def compute(self, reference_text: str, reference_words: list[str] = None,
         self.deletions_insertions_diff = self.num_deletions - self.num_insertions
         self.word_match = measures['hits']
         self.word_distance = measures['substitutions'] + measures['insertions'] + measures['deletions']
-        
+
         sm = SequenceMatcher()
         sm.set_seqs(reference_words, self.hypothesis_words)
-        self.match_words_frequencies = dict(Counter([reference_words[word_idx]
-                                                     for match in sm.get_matching_blocks()
-                                                     for word_idx in range(match[0], match[0] + match[2])]))
-        
-        #char match metrics
+        self.match_words_frequencies = dict(
+            Counter(
+                [
+                    reference_words[word_idx]
+                    for match in sm.get_matching_blocks()
+                    for word_idx in range(match[0], match[0] + match[2])
+                ]
+            )
+        )
+
+        # char match metrics
         self.char_distance = editdistance.eval(reference_text, self.hypothesis_text)
         self.cer = round(self.char_distance / reference_num_chars * 100.0, 2)
-    
+
     def get_table_metrics(self):
         postfix = ""
         if self.hypothesis_label != "":
             postfix = f"_{self.hypothesis_label}"
-        
+
         metrics = {
-            f"WER{postfix}" : self.wer,
-            f"CER{postfix}" : self.cer,
-            f"WMR{postfix}" : self.wmr,
-            f"I{postfix}" : self.num_insertions,
-            f"D{postfix}" : self.num_deletions,
-            f"D-I{postfix}" : self.deletions_insertions_diff
+            f"WER{postfix}": self.wer,
+            f"CER{postfix}": self.cer,
+            f"WMR{postfix}": self.wmr,
+            f"I{postfix}": self.num_insertions,
+            f"D{postfix}": self.num_deletions,
+            f"D-I{postfix}": self.deletions_insertions_diff,
         }
         return metrics