From fcb7aa95c1f5015d3c092249e2f7c11dc8da978d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 27 Nov 2023 14:54:37 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tools/speech_data_explorer/data_explorer.py  |  53 ++---
 .../sde/dataloader/dataset.py                | 198 +++++++++++-------
 .../sde/dataloader/engines/cudf_engine.py    |  31 +--
 .../sde/dataloader/sample.py                 | 117 ++++++-----
 4 files changed, 233 insertions(+), 166 deletions(-)

diff --git a/tools/speech_data_explorer/data_explorer.py b/tools/speech_data_explorer/data_explorer.py
index 264c191217a7..b3193a03b535 100755
--- a/tools/speech_data_explorer/data_explorer.py
+++ b/tools/speech_data_explorer/data_explorer.py
@@ -121,10 +121,7 @@ def parse_args():
         help='field name for which you want to see statistics (optional). Example: pred_text_contextnet.',
     )
     parser.add_argument(
-        '--gpu',
-        '-gpu',
-        action='store_true',
-        help='use GPU-acceleration',
+        '--gpu', '-gpu', action='store_true', help='use GPU-acceleration',
     )
 
     args = parser.parse_args()
@@ -490,7 +487,7 @@ def plot_histogram(data, key, label, gpu_acceleration=False):
         data_frame = data[key].to_list()
     else:
         data_frame = [item[key] for item in data]
-    
+
     fig = px.histogram(
         data_frame=data_frame,
         nbins=50,
@@ -504,10 +501,10 @@ def plot_histogram(data, key, label, gpu_acceleration=False):
     return fig
 
 
-def plot_word_accuracy(vocabulary_data): 
+def plot_word_accuracy(vocabulary_data):
     labels = ['Unrecognized', 'Sometimes recognized', 'Always recognized']
     counts = [0, 0, 0]
-    
+
     if args.gpu:
         counts[0] = (vocabulary_data['Accuracy'] == 0).sum()
         counts[1] = (vocabulary_data['Accuracy'] < 100).sum()
@@ -576,24 +573,27 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
 if args.gpu:
     if args.names_compared is not None:
         raise Exception(f"Currently comparison mode is not available with GPU acceleration.")
-    
+
     hypothesis_fields = ["pred_text"]
     if args.show_statistics is not None:
         hypothesis_fields = [args.show_statistics]
-    
+
     enable_plk = True
     if args.disable_caching_metrics:
         enable_plk = False
-    
+
     cu_df = cuDF()
-    dataset = Dataset(manifest_filepath = args.manifest, data_engine = cu_df,
-                      hypothesis_fields = hypothesis_fields,
-                      estimate_audio_metrics = args.estimate_audio_metrics,
-                      enable_plk = enable_plk)
-    
+    dataset = Dataset(
+        manifest_filepath=args.manifest,
+        data_engine=cu_df,
+        hypothesis_fields=hypothesis_fields,
+        estimate_audio_metrics=args.estimate_audio_metrics,
+        enable_plk=enable_plk,
+    )
+
     dataset = dataset.process()
-    
+
     data = dataset.samples_data
     num_hours = dataset.duration
     vocabulary = dataset.vocabulary_data
@@ -602,8 +602,8 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
     metrics_available = len(dataset.hypotheses) != 0
     if metrics_available:
         wer = dataset.hypotheses[hypothesis_fields[0]].wer
-        cer = dataset.hypotheses[hypothesis_fields[0]].cer 
-        wmr = dataset.hypotheses[hypothesis_fields[0]].wmr 
+        cer = dataset.hypotheses[hypothesis_fields[0]].cer
+        wmr = dataset.hypotheses[hypothesis_fields[0]].wmr
         mwa = dataset.hypotheses[hypothesis_fields[0]].mwa
 else:
     if not comparison_mode:
@@ -706,7 +706,7 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
         figure_word_acc = plot_word_accuracy(vocabulary_data)
 else:
     figure_word_acc = plot_word_accuracy(vocabulary)
-    
+
 stats_layout = [
     dbc.Row(dbc.Col(html.H5(children='Global Statistics'), class_name='text-secondary'), class_name='mt-3'),
     dbc.Row(
@@ -827,7 +827,7 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
 
 wordstable_columns = [{'name': 'Word', 'id': 'Word'}, {'name': 'Count', 'id': 'Amount'}]
 
-if args.gpu: 
+if args.gpu:
     vocabulary_columns = vocabulary.columns
 else:
     vocabulary_columns = vocabulary[0].keys()
@@ -910,7 +910,7 @@ def update_wordstable(page_current, sort_by, filter_query):
             if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
                 if args.gpu:
                     vocabulary_view = vocabulary_view.loc[getattr(operator, op)(vocabulary_view[col_name], filter_value)]
-                else: 
+                else:
                     vocabulary_view = [x for x in vocabulary_view if getattr(operator, op)(x[col_name], filter_value)]
             elif op == 'contains':
                 vocabulary_view = [x for x in vocabulary_view if filter_value in str(x[col_name])]
@@ -918,14 +918,14 @@
     if len(sort_by):
         col = sort_by[0]['column_id']
         ascending = sort_by[0]['direction'] != 'desc'
-        
+
         if args.gpu:
            vocabulary_view = vocabulary_view.sort_values(col, ascending=ascending)
        else:
            vocabulary_view = sorted(vocabulary_view, key=lambda x: x[col], reverse=not ascending)
    if page_current * DATA_PAGE_SIZE >= len(vocabulary_view):
        page_current = len(vocabulary_view) // DATA_PAGE_SIZE
-    
+
    if args.gpu:
        return [
            vocabulary_view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE].to_dict('records'),
@@ -937,6 +937,7 @@
            math.ceil(len(vocabulary_view) / DATA_PAGE_SIZE),
        ]
 
+
 if args.gpu:
     col_names = data.columns
 else:
@@ -1564,7 +1565,7 @@ def update_datatable(page_current, sort_by, filter_query):
            if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
                if args.gpu:
                    data_view = data_view.loc[getattr(operator, op)(data_view[col_name], filter_value)]
-                else: 
+                else:
                    data_view = [x for x in data_view if getattr(operator, op)(x[col_name], filter_value)]
            elif op == 'contains':
                data_view = [x for x in data_view if filter_value in str(x[col_name])]
@@ -1572,14 +1573,14 @@ def update_datatable(page_current, sort_by, filter_query):
    if len(sort_by):
        col = sort_by[0]['column_id']
        ascending = sort_by[0]['direction'] != 'desc'
-        
+
        if args.gpu:
            data_view = data_view.sort_values(col, ascending=ascending)
        else:
            data_view = sorted(data_view, key=lambda x: x[col], reverse=not ascending)
    if page_current * DATA_PAGE_SIZE >= len(data_view):
        page_current = len(data_view) // DATA_PAGE_SIZE
-    
+
    if args.gpu:
        return [
            data_view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE].to_dict('records'),
diff --git a/tools/speech_data_explorer/sde/dataloader/dataset.py b/tools/speech_data_explorer/sde/dataloader/dataset.py
index 085e0a9fa519..479fe23db15e 100644
--- a/tools/speech_data_explorer/sde/dataloader/dataset.py
+++ b/tools/speech_data_explorer/sde/dataloader/dataset.py
@@ -1,28 +1,40 @@
 import json
-from tqdm import tqdm
 import multiprocessing as mp
-import pickle
 import os
+import pickle
 from datetime import datetime
+
+from tqdm import tqdm
+
 from nemo.utils import logging
-from .sample import Sample
+
 from .engines.cudf_engine import cuDF
+from .sample import Sample
+
 
 class Dataset:
-    def __init__(self, manifest_filepath: str, chunksize: int = 10000, data_engine: object = None, n_jobs: int = -1,
-                 reference_field = "text", hypothesis_fields: list[str] = ["pred_text"], hypothesis_labels: list[str] = None,
-                 estimate_audio_metrics: bool = False,
-                 enable_plk: bool = True, plk_filepath: str = None):
+    def __init__(
+        self,
+        manifest_filepath: str,
+        chunksize: int = 10000,
+        data_engine: object = None,
+        n_jobs: int = -1,
+        reference_field="text",
+        hypothesis_fields: list[str] = ["pred_text"],
+        hypothesis_labels: list[str] = None,
+        estimate_audio_metrics: bool = False,
+        enable_plk: bool = True,
+        plk_filepath: str = None,
+    ):
         self.manifest_filepath = manifest_filepath
         self.chunksize = chunksize
         self.data_engine = data_engine
         self.n_jobs = n_jobs
-        
+
         max_jobs = mp.cpu_count()
         if self.n_jobs == -1 or n_jobs > max_jobs:
             self.n_jobs = max_jobs
-        
+
         self.reference_field = reference_field
         self.hypothesis_fields = hypothesis_fields
         self.hypothesis_labels = hypothesis_labels
@@ -31,16 +43,16 @@ def __init__(self, manifest_filepath: str, chunksize: int = 10000, data_engine:
         self.enable_plk = enable_plk
         self.plk_filepath = plk_filepath
         self.chunks = []
-        
+
         self.num_words = 0
         self.num_chars = 0
         self.duration = 0
         self.charset = set()
         self.words_frequencies = dict()
-        
+
         self.samples_data = []
         self.vocabulary_data = []
-    
+
     def _check_hypotheses(self, manifest_line: str):
         if self.hypothesis_fields is not None:
             if self.hypothesis_labels is None:
@@ -48,13 +60,15 @@ def _check_hypotheses(self, manifest_line: str):
                     self.hypothesis_labels = [""]
                 else:
                     self.hypothesis_labels = list(range(1, len(self.hypothesis_fields) + 1))
-            
+
             if len(self.hypothesis_labels) != len(self.hypothesis_fields):
-                logging.error(f"Amount of hypothesis_labels ({len(self.hypothesis_labels)}) is not equal to amount of hypothesis_fields ({len(self.hypothesis_fields)}).")
+                logging.error(
+                    f"Amount of hypothesis_labels ({len(self.hypothesis_labels)}) is not equal to amount of hypothesis_fields ({len(self.hypothesis_fields)})."
+                )
                 raise
             else:
                 sample_to_check = json.loads(manifest_line)
-                
+
                 i = 0
                 while i < len(self.hypothesis_fields):
                     hypothesis_field = self.hypothesis_fields[i]
@@ -63,103 +77,121 @@ def _check_hypotheses(self, manifest_line: str):
                         self.hypothesis_fields.pop(i)
                         self.hypothesis_labels.pop(i)
                     else:
-                        logging.info(f"Field '{hypothesis_field}' was found (labeled as '{self.hypothesis_labels[i]}').")
-                        self.hypotheses[hypothesis_field] = HypothesisMetrics(hypothesis_label = self.hypothesis_labels[i])
+                        logging.info(
+                            f"Field '{hypothesis_field}' was found (labeled as '{self.hypothesis_labels[i]}')."
+                        )
+                        self.hypotheses[hypothesis_field] = HypothesisMetrics(
+                            hypothesis_label=self.hypothesis_labels[i]
+                        )
                     i += 1
-    
-    def _read_manifest(self): 
+
+    def _read_manifest(self):
         logging.info("Reading manifest..")
-        with open(self.manifest_filepath, 'r', encoding = "utf8") as manifest:
+        with open(self.manifest_filepath, 'r', encoding="utf8") as manifest:
             lines = manifest.readlines()
-        
+
         self._check_hypotheses(lines[0])
-        
+
         lines_amount = len(lines)
         logging.info(f"Lines amount: {lines_amount}. Splitting into chunks ({self.chunksize} lines per chunk)..")
-        
+
         start_chunk_indicies = list(range(0, lines_amount, self.chunksize))
         end_chunk_indicies = list(range(self.chunksize, lines_amount, self.chunksize)) + [lines_amount]
-        
-        for start_idx, end_idx in tqdm(zip(start_chunk_indicies, end_chunk_indicies), total = len(start_chunk_indicies)):
-            chunk = DataChunk(manifest_lines = lines[start_idx : end_idx], data_engine = self.data_engine, reference_field = self.reference_field,
-                              hypothesis_fields = self.hypothesis_fields, hypothesis_labels = self.hypothesis_labels,
-                              estimate_audio_metrics = self.estimate_audio_metrics)
+
+        for start_idx, end_idx in tqdm(zip(start_chunk_indicies, end_chunk_indicies), total=len(start_chunk_indicies)):
+            chunk = DataChunk(
+                manifest_lines=lines[start_idx:end_idx],
+                data_engine=self.data_engine,
+                reference_field=self.reference_field,
+                hypothesis_fields=self.hypothesis_fields,
+                hypothesis_labels=self.hypothesis_labels,
+                estimate_audio_metrics=self.estimate_audio_metrics,
+            )
             self.chunks.append(chunk)
-    
+
     def _get_plk_filepath(self):
         timestamp = datetime.fromtimestamp(os.path.getmtime(self.manifest_filepath)).strftime('%Y-%m-%d_%H-%M-%S')
         return f"{self.manifest_filepath.replace('.json', '')}_{timestamp}.pkl"
-    
+
     def _read_pickle(self):
         with open(self.plk_filepath, 'rb') as pkl:
             return pickle.load(pkl)
-    
+
     def _write_pickle(self):
         logging.info(f'Saving .plk file..')
         with open(self.plk_filepath, 'wb') as pkl:
             pickle.dump(self, pkl, pickle.HIGHEST_PROTOCOL)
-        
+
         logging.info(f'{self.plk_filepath} saved.')
-    
+
     def process(self):
         if self.enable_plk:
             logging.info(f'Looking for .plk file ({self.plk_filepath})')
             if self.plk_filepath is None:
                 self.plk_filepath = self._get_plk_filepath()
-            
+
             if os.path.exists(self.plk_filepath):
                 logging.info(f'{self.plk_filepath} found.')
                 return self._read_pickle()
             else:
                 logging.info(f'{self.plk_filepath} not found. Loading data from manifest..')
-        
+
         self._read_manifest()
-        
+
         processed_chunks = []
         logging.info(f'Samples processing ({self.n_jobs} processes)..')
         with mp.Pool(self.n_jobs) as pool:
             for processed_chunk in tqdm(pool.imap(DataChunk.process, self.chunks), total=len(self.chunks)):
                 processed_chunks.append(processed_chunk)
-        
+
         self.chunks = processed_chunks
-        
+
         logging.info(f'Global metrics computing..')
         for chunk in tqdm(self.chunks):
             self.num_words += chunk.num_words
             self.num_chars += chunk.num_chars
             self.duration += chunk.duration
             self.charset.update(chunk.charset)
-            
+
             for hypothesis_field in chunk.hypotheses:
                 self.hypotheses[hypothesis_field].update(chunk.hypotheses[hypothesis_field])
-            
+
             for word in chunk.words_frequencies:
                 self.words_frequencies[word] = self.words_frequencies.get(word, 0) + chunk.words_frequencies[word]
-            
+
             if self.data_engine is not None:
                 self.samples_data.append(chunk.samples_data)
-        
+
         for hypothesis_field in self.hypotheses:
-            self.hypotheses[hypothesis_field].compute(dataset_num_words = self.num_words, dataset_num_chars = self.num_chars)
-        
-        self.duration = round(self.duration / 3600 , 2)
-        
+            self.hypotheses[hypothesis_field].compute(
+                dataset_num_words=self.num_words, dataset_num_chars=self.num_chars
+            )
+
+        self.duration = round(self.duration / 3600, 2)
+
         if self.data_engine is not None:
             logging.info(f'Samples datatable loading..')
             self.samples_data = self.data_engine.concat_samples_chunks(self.samples_data)
-            self.vocabulary_data = self.data_engine.process_vocabulary(words_frequencies = self.words_frequencies,
-                                                                       hypotheses_metrics = self.hypotheses.values())
-        
+            self.vocabulary_data = self.data_engine.process_vocabulary(
+                words_frequencies=self.words_frequencies, hypotheses_metrics=self.hypotheses.values()
+            )
+
         if self.enable_plk:
-            self._write_pickle() 
-        
+            self._write_pickle()
+
         return self
-    
+

 class DataChunk:
-    def __init__(self, manifest_lines: list[str], data_engine: object = None,
-                 reference_field: str = "text", hypothesis_fields: list[str] = ["pred_text"], hypothesis_labels: list[str] = None,
-                 estimate_audio_metrics: bool = False):
+    def __init__(
+        self,
+        manifest_lines: list[str],
+        data_engine: object = None,
+        reference_field: str = "text",
+        hypothesis_fields: list[str] = ["pred_text"],
+        hypothesis_labels: list[str] = None,
+        estimate_audio_metrics: bool = False,
+    ):
         self.manifest_lines = manifest_lines
         self.reference_field = reference_field
         self.estimate_audio_metrics = estimate_audio_metrics
@@ -168,70 +200,76 @@ def __init__(self, manifest_lines: list[str], data_engine: object = None,
         self.num_chars = 0
         self.duration = 0
         self.charset = set()
-        self.words_frequencies = dict() 
-        
+        self.words_frequencies = dict()
+
         self.hypothesis_fields = hypothesis_fields
         self.hypothesis_labels = hypothesis_labels
         self.hypotheses = dict()
-        
+
         for field, label in zip(hypothesis_fields, hypothesis_labels):
-            self.hypotheses[field] = HypothesisMetrics(hypothesis_label = label)
-        
-        self.data_engine = data_engine 
+            self.hypotheses[field] = HypothesisMetrics(hypothesis_label=label)
+
+        self.data_engine = data_engine
         self.samples_data = None
-    
+
     def process(self):
         sample = Sample()
         for manifest_line in self.manifest_lines:
-            sample.parse_line(manifest_line, reference_field = self.reference_field,
-                              hypothesis_fields = self.hypothesis_fields,
-                              hypothesis_labels = self.hypothesis_labels)
-            sample.compute(estimate_audio_metrics = self.estimate_audio_metrics)
-            
+            sample.parse_line(
+                manifest_line,
+                reference_field=self.reference_field,
+                hypothesis_fields=self.hypothesis_fields,
+                hypothesis_labels=self.hypothesis_labels,
+            )
+            sample.compute(estimate_audio_metrics=self.estimate_audio_metrics)
+
             self.samples_dicts.append(sample.sample_dict)
             self.num_words += sample.num_words
             self.num_chars += sample.num_chars
             self.duration += sample.duration
             self.charset.update(sample.charset)
-            
+
             for word in sample.words_frequencies:
                 self.words_frequencies[word] = self.words_frequencies.get(word, 0) + sample.words_frequencies[word]
-            
+
             for hypothesis_field in sample.hypotheses:
                 self.hypotheses[hypothesis_field].update(sample.hypotheses[hypothesis_field])
-            
+
             sample.reset()
-        
+
         if self.data_engine is not None:
             self.samples_data = self.data_engine.load_samples_chunk(self.samples_dicts)
             self.samples_dicts = {}
-        
+
         return self
-    
+
+
 class HypothesisMetrics:
     def __init__(self, hypothesis_label: str = None):
         self.hypothesis_label = hypothesis_label
         self.word_distance = 0
         self.word_match = 0
         self.char_distance = 0
-        
+
         self.wer = None
         self.wmr = None
         self.cer = None
         self.mwa = None
-        
+
         self.match_words_frequencies = dict()
 
-    def update(self, hypothesis: object): 
+    def update(self, hypothesis: object):
         assert self.hypothesis_label == hypothesis.hypothesis_label, "Hypothesis label mismatch!"
-        
+
         self.word_distance += hypothesis.word_distance
         self.word_match += hypothesis.word_match
         self.char_distance += hypothesis.char_distance
-        
+
         for word in hypothesis.match_words_frequencies:
-            self.match_words_frequencies[word] = self.match_words_frequencies.get(word, 0) + hypothesis.match_words_frequencies[word]
-        
+            self.match_words_frequencies[word] = (
+                self.match_words_frequencies.get(word, 0) + hypothesis.match_words_frequencies[word]
+            )
+
     def compute(self, dataset_num_words: int, dataset_num_chars: int):
         self.wer = round(self.word_distance / dataset_num_words * 100.0, 2)
         self.wmr = round(self.word_match / dataset_num_words * 100.0, 2)
diff --git a/tools/speech_data_explorer/sde/dataloader/engines/cudf_engine.py b/tools/speech_data_explorer/sde/dataloader/engines/cudf_engine.py
index f07fb15955ba..2e1d5a041e0e 100644
--- a/tools/speech_data_explorer/sde/dataloader/engines/cudf_engine.py
+++ b/tools/speech_data_explorer/sde/dataloader/engines/cudf_engine.py
@@ -1,4 +1,5 @@
 import cudf.pandas
+
 cudf.pandas.install()
 
 import pandas as pd
@@ -6,38 +7,42 @@
 class cuDF:
     def __init__(self):
         pass
-    
+
     def load_samples_chunk(self, samples: list[dict]):
         chunk = pd.DataFrame(samples)
         return chunk
-    
+
     def concat_samples_chunks(self, samples_chunks: list):
         samples_datatable = pd.concat(samples_chunks).reset_index(drop=True)
         return samples_datatable
-    
+
     def process_vocabulary(self, words_frequencies: dict, hypotheses_metrics: list[object]):
         vocabulary_dfs = []
-        
+
         words_frequencies_df = pd.DataFrame(words_frequencies.items(), columns=["Word", "Amount"]).set_index("Word")
         vocabulary_dfs.append(words_frequencies_df)
-        
+
         for hypothesis_metrics_obj in hypotheses_metrics:
             label = hypothesis_metrics_obj.hypothesis_label
             match_words_frequencies = hypothesis_metrics_obj.match_words_frequencies
-            match_words_frequencies_df = pd.DataFrame(match_words_frequencies.items(), columns=["Word", f"Match_{hypothesis_metrics_obj.hypothesis_label}"]).set_index("Word")
+            match_words_frequencies_df = pd.DataFrame(
+                match_words_frequencies.items(), columns=["Word", f"Match_{hypothesis_metrics_obj.hypothesis_label}"]
+            ).set_index("Word")
             vocabulary_dfs.append(match_words_frequencies_df)
-        
-        vocabulary_datatable = pd.concat(vocabulary_dfs, axis = 1, join = "outer").reset_index().fillna(0)
-        
+        vocabulary_datatable = pd.concat(vocabulary_dfs, axis=1, join="outer").reset_index().fillna(0)
+
         for hypothesis_metrics_obj in hypotheses_metrics:
             label = hypothesis_metrics_obj.hypothesis_label
             postfix = ""
             if label != "":
                 postfix = f"_{label}"
-            
-            vocabulary_datatable[f"Accuracy{postfix}"] = vocabulary_datatable[f"Match_{label}"] / vocabulary_datatable["Amount"] * 100
+
+            vocabulary_datatable[f"Accuracy{postfix}"] = (
+                vocabulary_datatable[f"Match_{label}"] / vocabulary_datatable["Amount"] * 100
+            )
             vocabulary_datatable[f"Accuracy{postfix}"] = vocabulary_datatable[f"Accuracy{postfix}"].round(2)
-            vocabulary_datatable = vocabulary_datatable.drop(f"Match_{label}", axis = 1)
+            vocabulary_datatable = vocabulary_datatable.drop(f"Match_{label}", axis=1)
             hypothesis_metrics_obj.mwa = round(vocabulary_datatable[f"Accuracy{postfix}"].mean(), 2)
-        
+
         return vocabulary_datatable
diff --git a/tools/speech_data_explorer/sde/dataloader/sample.py b/tools/speech_data_explorer/sde/dataloader/sample.py
index 2cebc9e0532e..ab0c22189d7e 100644
--- a/tools/speech_data_explorer/sde/dataloader/sample.py
+++ b/tools/speech_data_explorer/sde/dataloader/sample.py
@@ -1,10 +1,11 @@
 import json
 from collections import Counter
-import jiwer
 from difflib import SequenceMatcher
+
 import editdistance
-import numpy as np
+import jiwer
 import librosa
+import numpy as np
 
 
 class Sample:
@@ -31,20 +32,24 @@ def reset(self):
         self.frequency_bandwidth = None
         self.level_db = None
         self.hypotheses = {}
-    
-    def parse_line(self, manifest_line: str, reference_field: str = "text",
-                   hypothesis_fields: list[str] = ["pred_text"],
-                   hypothesis_labels: list[str] = None):
-        
+
+    def parse_line(
+        self,
+        manifest_line: str,
+        reference_field: str = "text",
+        hypothesis_fields: list[str] = ["pred_text"],
+        hypothesis_labels: list[str] = None,
+    ):
+
         self.sample_dict = json.loads(manifest_line)
         self.reference_text = self.sample_dict.get(reference_field, None)
         self.duration = self.sample_dict.get("duration", None)
-        
+
         if hypothesis_labels is None:
             hypothesis_labels = list(range(1, len(hypothesis_fields) + 1))
-        
+
         for field, label in zip(hypothesis_fields, hypothesis_labels):
-            hypothesis = Hypothesis(hypothesis_text = self.sample_dict[field], hypothesis_label = label)
+            hypothesis = Hypothesis(hypothesis_text=self.sample_dict[field], hypothesis_label=label)
             self.hypotheses[field] = hypothesis
 
     def compute(self, estimate_audio_metrics: bool = False):
@@ -53,24 +58,29 @@ def compute(self, estimate_audio_metrics: bool = False):
         self.num_words = len(self.words)
         self.charset = set(self.reference_text)
         self.words_frequencies = dict(Counter(self.words))
-        
+
         if self.duration is not None:
             self.char_rate = round(self.num_chars / self.duration, 2)
             self.word_rate = round(self.num_words / self.duration, 2)
-        
+
         if len(self.hypotheses) != 0:
             for label in self.hypotheses:
-                self.hypotheses[label].compute(reference_text = self.reference_text, reference_words = self.words,
-                                               reference_num_words = self.num_words, reference_num_chars = self.num_chars)
-        
+                self.hypotheses[label].compute(
+                    reference_text=self.reference_text,
+                    reference_words=self.words,
+                    reference_num_words=self.num_words,
+                    reference_num_chars=self.num_chars,
+                )
+
         if estimate_audio_metrics and self.audio_filepath is not None:
-            
+
             def eval_signal_frequency_bandwidth(signal, sampling_rate, threshold=-50) -> float:
                 time_stride = 0.01
                 hop_length = int(sampling_rate * time_stride)
                 n_fft = 512
                 spectrogram = np.mean(
-                    np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2, axis=1
+                    np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2,
+                    axis=1,
                 )
                 power_spectrum = librosa.power_to_db(S=spectrogram, ref=np.max, top_db=100)
                 frequency_bandwidth = 0
@@ -78,25 +88,27 @@ def eval_signal_frequency_bandwidth(self, signal, sampling_rate, threshold=-50)
                     if power_spectrum[idx] > threshold:
                         frequency_bandwidth = idx / n_fft * sampling_rate
                         break
-            
+
                 return frequency_bandwidth
-            
+
             self.signal, self.sampling_rate = librosa.load(path=self.audio_filepath, sr=None)
-            self.frequency_bandwidth = eval_signal_frequency_bandwidth(signal=self.signal, sampling_rate=self.sampling_rate)
+            self.frequency_bandwidth = eval_signal_frequency_bandwidth(
+                signal=self.signal, sampling_rate=self.sampling_rate
+            )
             self.level_db = 20 * np.log10(np.max(np.abs(self.signal)))
 
         self.add_table_metrics_to_dict()
-    
+
     def add_table_metrics_to_dict(self):
         metrics = {
             "num_chars": self.num_chars,
             "num_words": self.num_words,
         }
-        
+
         if self.duration is not None:
             metrics["char_rate"] = self.char_rate
             metrics["word_rate"] = self.word_rate
-        
+
         if len(self.hypotheses) != 0:
             for label in self.hypotheses:
                 hypothesis_metrics = self.hypotheses[label].get_table_metrics()
@@ -105,7 +117,7 @@ def add_table_metrics_to_dict(self):
         if self.frequency_bandwidth is not None:
             metrics["freq_bandwidth"] = self.frequency_bandwidth
             metrics["level_db"] = self.level_db
-        
+
         self.sample_dict.update(metrics)
 
 
@@ -114,7 +126,7 @@ def __init__(self, hypothesis_text: str, hypothesis_label: str = None):
         self.hypothesis_text = hypothesis_text
         self.hypothesis_label = hypothesis_label
         self.hypothesis_words = None
-        
+
         self.wer = None
         self.wmr = None
         self.num_insertions = None
@@ -123,23 +135,28 @@ def __init__(self, hypothesis_text: str, hypothesis_label: str = None):
         self.word_match = None
         self.word_distance = None
         self.match_words_frequencies = dict()
-        
+
         self.char_distance = None
         self.cer = None
-    
-    def compute(self, reference_text: str, reference_words: list[str] = None,
-                reference_num_words: int = None, reference_num_chars: int = None):
-        
+
+    def compute(
+        self,
+        reference_text: str,
+        reference_words: list[str] = None,
+        reference_num_words: int = None,
+        reference_num_chars: int = None,
+    ):
+
         if reference_words is None:
             reference_words = reference_text.split()
         if reference_num_words is None:
             reference_num_words = len(reference_words)
         if reference_num_chars is None:
             reference_num_chars = len(reference_text)
-        
+
         self.hypothesis_words = self.hypothesis_text.split()
-        
-        #word match metrics
+
+        # word match metrics
         measures = jiwer.compute_measures(reference_text, self.hypothesis_text)
 
         self.wer = round(measures['wer'] * 100.0, 2)
@@ -149,28 +166,34 @@ def compute(self, reference_text: str, reference_words: list[str] = None,
         self.deletions_insertions_diff = self.num_deletions - self.num_insertions
         self.word_match = measures['hits']
         self.word_distance = measures['substitutions'] + measures['insertions'] + measures['deletions']
-        
+
         sm = SequenceMatcher()
         sm.set_seqs(reference_words, self.hypothesis_words)
-        self.match_words_frequencies = dict(Counter([reference_words[word_idx]
-                                                     for match in sm.get_matching_blocks()
-                                                     for word_idx in range(match[0], match[0] + match[2])]))
-        
-        #char match metrics
+        self.match_words_frequencies = dict(
+            Counter(
+                [
+                    reference_words[word_idx]
+                    for match in sm.get_matching_blocks()
+                    for word_idx in range(match[0], match[0] + match[2])
+                ]
+            )
+        )
+
+        # char match metrics
         self.char_distance = editdistance.eval(reference_text, self.hypothesis_text)
         self.cer = round(self.char_distance / reference_num_chars * 100.0, 2)
-    
+
     def get_table_metrics(self):
         postfix = ""
         if self.hypothesis_label != "":
             postfix = f"_{self.hypothesis_label}"
-        
+
         metrics = {
-            f"WER{postfix}" : self.wer,
-            f"CER{postfix}" : self.cer,
-            f"WMR{postfix}" : self.wmr,
-            f"I{postfix}" : self.num_insertions,
-            f"D{postfix}" : self.num_deletions,
-            f"D-I{postfix}" : self.deletions_insertions_diff
+            f"WER{postfix}": self.wer,
+            f"CER{postfix}": self.cer,
+            f"WMR{postfix}": self.wmr,
+            f"I{postfix}": self.num_insertions,
+            f"D{postfix}": self.num_deletions,
+            f"D-I{postfix}": self.deletions_insertions_diff,
         }
         return metrics