Skip to content

Commit dcba064

Browse files
authored
Main script refactoring: metrics calculation
Sample metric computations are now grouped as methods of the Sample class. Also added support for calculating metrics for multiple hypotheses. Signed-off-by: Sasha Meister <[email protected]>
1 parent df20f08 commit dcba064

File tree

1 file changed

+176
-0
lines changed
  • tools/speech_data_explorer/sde/dataloader

1 file changed

+176
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import json
2+
from collections import Counter
3+
import jiwer
4+
from difflib import SequenceMatcher
5+
import editdistance
6+
import numpy as np
7+
import librosa
8+
9+
10+
class Sample:
    """One manifest entry together with its text, audio and hypothesis metrics.

    Intended flow: ``parse_line`` loads a JSON manifest line, ``compute``
    derives the metrics and merges the table-level ones back into
    ``sample_dict``. Call ``reset`` before reusing an instance for another
    line; ``parse_line`` itself does not clear previously parsed hypotheses.
    """

    def __init__(self):
        # All state is defined in ``reset`` so that a fresh instance and a
        # recycled one are guaranteed to look the same.
        self.reset()

    def reset(self):
        """Return every attribute to its initial, empty state."""
        self.sample_dict = {}
        self.audio_filepath = None
        self.reference_text = None
        self.num_chars = None
        self.charset = set()
        self.words = None
        self.num_words = None
        self.words_frequencies = None
        self.duration = None
        self.char_rate = None
        self.word_rate = None
        self.signal = None
        self.sampling_rate = None
        self.frequency_bandwidth = None
        self.level_db = None
        self.hypotheses = {}

    def parse_line(self, manifest_line: str, reference_field: str = "text",
                   hypothesis_fields: list[str] = None,
                   hypothesis_labels: list[str] = None):
        """Load one JSON manifest line into this sample.

        Args:
            manifest_line: JSON object serialized as a string.
            reference_field: Key holding the reference transcript.
            hypothesis_fields: Keys holding hypothesis transcripts. Defaults
                to ``["pred_text"]``. Every listed key must be present in the
                manifest line.
            hypothesis_labels: Labels used as metric-name suffixes, paired
                positionally with ``hypothesis_fields``. Defaults to
                ``1..len(hypothesis_fields)``. Pairing uses ``zip``, so extra
                entries on either side are ignored.

        Raises:
            json.JSONDecodeError: If ``manifest_line`` is not valid JSON.
            KeyError: If a hypothesis field is missing from the line.
        """
        # ``None`` sentinel instead of a list literal: a mutable default would
        # be shared between all calls.
        if hypothesis_fields is None:
            hypothesis_fields = ["pred_text"]

        self.sample_dict = json.loads(manifest_line)
        self.reference_text = self.sample_dict.get(reference_field)
        self.duration = self.sample_dict.get("duration")
        # NOTE(review): assumes the audio path is stored under the
        # "audio_filepath" key (NeMo manifest convention) - confirm. A path
        # assigned by the caller beforehand is kept when the key is absent.
        self.audio_filepath = self.sample_dict.get("audio_filepath", self.audio_filepath)

        if hypothesis_labels is None:
            hypothesis_labels = list(range(1, len(hypothesis_fields) + 1))

        for field, label in zip(hypothesis_fields, hypothesis_labels):
            self.hypotheses[field] = Hypothesis(
                hypothesis_text=self.sample_dict[field], hypothesis_label=label
            )

    def compute(self, estimate_audio_metrics: bool = False):
        """Compute text, rate, hypothesis and (optionally) audio metrics.

        Args:
            estimate_audio_metrics: When true and an audio path is known, load
                the audio file and estimate its frequency bandwidth and peak
                level in dB.

        Raises:
            ValueError: If no reference text has been parsed.
        """
        if self.reference_text is None:
            raise ValueError("reference text is missing; call parse_line() with a valid reference_field first")

        self.num_chars = len(self.reference_text)
        self.words = self.reference_text.split()
        self.num_words = len(self.words)
        self.charset = set(self.reference_text)
        self.words_frequencies = dict(Counter(self.words))

        # Rates are only defined for a known, positive duration; a zero
        # duration would otherwise raise ZeroDivisionError.
        if self.duration is not None and self.duration > 0:
            self.char_rate = round(self.num_chars / self.duration, 2)
            self.word_rate = round(self.num_words / self.duration, 2)
        else:
            self.char_rate = None
            self.word_rate = None

        for hypothesis in self.hypotheses.values():
            hypothesis.compute(
                reference_text=self.reference_text,
                reference_words=self.words,
                reference_num_words=self.num_words,
                reference_num_chars=self.num_chars,
            )

        if estimate_audio_metrics and self.audio_filepath is not None:
            # sr=None keeps the file's native sampling rate.
            self.signal, self.sampling_rate = librosa.load(path=self.audio_filepath, sr=None)
            self.frequency_bandwidth = self._eval_signal_frequency_bandwidth(
                signal=self.signal, sampling_rate=self.sampling_rate
            )
            peak = float(np.max(np.abs(self.signal)))
            # log10(0) would emit a RuntimeWarning; a silent signal is -inf dB.
            self.level_db = float(20 * np.log10(peak)) if peak > 0 else float("-inf")

        self.add_table_metrics_to_dict()

    @staticmethod
    def _eval_signal_frequency_bandwidth(signal, sampling_rate, threshold=-50) -> float:
        """Estimate the highest frequency whose mean power exceeds ``threshold``.

        The power spectrum is averaged over STFT frames and expressed in dB
        relative to its own maximum; the result is the frequency of the
        highest bin above ``threshold`` dB, or 0.0 when no bin qualifies.
        """
        time_stride = 0.01
        hop_length = int(sampling_rate * time_stride)
        n_fft = 512
        mean_power = np.mean(
            np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2,
            axis=1,
        )
        power_db = librosa.power_to_db(S=mean_power, ref=np.max, top_db=100)
        bins_above = np.flatnonzero(power_db > threshold)
        if bins_above.size == 0:
            return 0.0
        return float(bins_above[-1]) / n_fft * sampling_rate

    def add_table_metrics_to_dict(self):
        """Merge the scalar, table-level metrics into ``sample_dict``."""
        metrics = {
            "num_chars": self.num_chars,
            "num_words": self.num_words,
        }

        # Rates exist only when compute() saw a usable duration.
        if self.char_rate is not None:
            metrics["char_rate"] = self.char_rate
            metrics["word_rate"] = self.word_rate

        for hypothesis in self.hypotheses.values():
            metrics.update(hypothesis.get_table_metrics())

        if self.frequency_bandwidth is not None:
            metrics["freq_bandwidth"] = self.frequency_bandwidth
            metrics["level_db"] = self.level_db

        self.sample_dict.update(metrics)
110+
111+
112+
class Hypothesis:
    """A hypothesis transcript and its error metrics against a reference.

    ``compute`` fills the word-level (WER, WMR, insertions, deletions, matched
    words) and character-level (edit distance, CER) metrics;
    ``get_table_metrics`` exposes the scalar ones under label-suffixed keys.
    """

    def __init__(self, hypothesis_text: str, hypothesis_label: str = None):
        self.hypothesis_text = hypothesis_text
        self.hypothesis_label = hypothesis_label
        self.hypothesis_words = None

        # Word-level metrics.
        self.wer = None
        self.wmr = None
        self.num_insertions = None
        self.num_deletions = None
        self.deletions_insertions_diff = None
        self.word_match = None
        self.word_distance = None
        self.match_words_frequencies = dict()

        # Character-level metrics.
        self.char_distance = None
        self.cer = None

    def compute(self, reference_text: str, reference_words: list[str] = None,
                reference_num_words: int = None, reference_num_chars: int = None):
        """Compute all metrics of this hypothesis against ``reference_text``.

        Args:
            reference_text: Reference transcript.
            reference_words: Pre-split reference words; derived from
                ``reference_text`` when omitted.
            reference_num_words: Number of reference words; derived when omitted.
            reference_num_chars: Number of reference characters; derived when
                omitted.

        Raises:
            ValueError: If the reference has no words or no characters, since
                the rates below are normalized by those counts.
        """
        if reference_words is None:
            reference_words = reference_text.split()
        if reference_num_words is None:
            reference_num_words = len(reference_words)
        if reference_num_chars is None:
            reference_num_chars = len(reference_text)

        # Fail with a clear message instead of a ZeroDivisionError further down.
        if reference_num_words == 0 or reference_num_chars == 0:
            raise ValueError("cannot compute WER/CER against an empty reference")

        self.hypothesis_words = self.hypothesis_text.split()

        # Word match metrics.
        measures = jiwer.compute_measures(reference_text, self.hypothesis_text)

        self.wer = round(measures['wer'] * 100.0, 2)
        self.wmr = round(measures['hits'] / reference_num_words * 100.0, 2)
        self.num_insertions = measures['insertions']
        self.num_deletions = measures['deletions']
        self.deletions_insertions_diff = self.num_deletions - self.num_insertions
        self.word_match = measures['hits']
        self.word_distance = measures['substitutions'] + measures['insertions'] + measures['deletions']

        self.match_words_frequencies = self._matched_word_frequencies(
            reference_words, self.hypothesis_words
        )

        # Char match metrics.
        self.char_distance = editdistance.eval(reference_text, self.hypothesis_text)
        self.cer = round(self.char_distance / reference_num_chars * 100.0, 2)

    @staticmethod
    def _matched_word_frequencies(reference_words, hypothesis_words):
        """Count reference words that fall inside SequenceMatcher matching blocks."""
        # autojunk=False: the default heuristic discards words that are
        # frequent in a hypothesis of 200+ words, which would undercount
        # matches of common words in long transcripts.
        matcher = SequenceMatcher(a=reference_words, b=hypothesis_words, autojunk=False)
        counts = Counter()
        for block in matcher.get_matching_blocks():
            counts.update(reference_words[block.a:block.a + block.size])
        return dict(counts)

    def get_table_metrics(self):
        """Return the scalar metrics keyed by name plus an optional ``_<label>`` suffix.

        No suffix is added when the label is ``None`` or an empty string.
        """
        # ``None`` is the constructor default, so it must not become "_None".
        if self.hypothesis_label is None or self.hypothesis_label == "":
            postfix = ""
        else:
            postfix = f"_{self.hypothesis_label}"

        metrics = {
            f"WER{postfix}": self.wer,
            f"CER{postfix}": self.cer,
            f"WMR{postfix}": self.wmr,
            f"I{postfix}": self.num_insertions,
            f"D{postfix}": self.num_deletions,
            f"D-I{postfix}": self.deletions_insertions_diff,
        }
        return metrics

0 commit comments

Comments
 (0)