
Commit fc59941

Description added
Signed-off-by: Sasha Meister <[email protected]>
1 parent f08dc9f commit fc59941

File tree

1 file changed: +143 −70 lines changed
  • tools/speech_data_explorer/sde/dataloader

@@ -1,14 +1,44 @@
 import json
 from collections import Counter
+import jiwer
 from difflib import SequenceMatcher
-
 import editdistance
-import jiwer
-import librosa
 import numpy as np
+import librosa
 
 
 class Sample:
+    """
+    A class representing a sample of data, including reference and hypothesis texts, for processing and analysis.
+
+    Attributes:
+    - reference_text (str): The reference text associated with the sample.
+    - num_chars (int): Number of characters in the reference text.
+    - charset (set): Set of unique characters in the reference text.
+    - words (list): List of words in the reference text.
+    - num_words (int): Number of words in the reference text.
+    - words_frequencies (dict): Dictionary containing word frequencies in the reference text.
+    - duration (float): Duration of the audio in the sample.
+    - frequency_bandwidth (float): Frequency bandwidth of the audio signal (computed if audio file provided).
+    - level_db (float): Level of the audio signal in decibels (computed if audio file provided).
+    - hypotheses (dict): Dictionary containing hypothesis objects for different fields.
+
+    Methods:
+    - reset():
+      Resets the sample attributes to their initial state.
+
+    - parse_line(manifest_line: str, reference_field: str = "text",
+                 hypothesis_fields: list[str] = ["pred_text"],
+                 hypothesis_labels: list[str] = None):
+      Parses a line from the manifest file and updates the sample information.
+
+    - compute(estimate_audio_metrics: bool = False):
+      Computes metrics for the sample, including word frequencies and audio metrics if specified.
+
+    - add_table_metrics_to_dict():
+      Adds computed metrics to the sample dictionary.
+    """
+
     def __init__(self):
         self.reference_text = None
         self.num_chars = None
@@ -22,6 +52,10 @@ def __init__(self):
         self.hypotheses = {}
 
     def reset(self):
+        """
+        Resets the sample attributes to their initial state.
+        """
+
         self.reference_text = None
         self.num_chars = None
         self.charset = set()
@@ -32,83 +66,86 @@ def reset(self):
         self.frequency_bandwidth = None
         self.level_db = None
         self.hypotheses = {}
-
-    def parse_line(
-        self,
-        manifest_line: str,
-        reference_field: str = "text",
-        hypothesis_fields: list[str] = ["pred_text"],
-        hypothesis_labels: list[str] = None,
-    ):
-
+
+    def parse_line(self, manifest_line: str, reference_field: str = "text",
+                   hypothesis_fields: list[str] = ["pred_text"],
+                   hypothesis_labels: list[str] = None):
+        """
+        Parses a line from the manifest file and updates the sample information.
+        """
+
         self.sample_dict = json.loads(manifest_line)
         self.reference_text = self.sample_dict.get(reference_field, None)
         self.duration = self.sample_dict.get("duration", None)
-
+
         if hypothesis_labels is None:
             hypothesis_labels = list(range(1, len(hypothesis_fields) + 1))
-
+
         for field, label in zip(hypothesis_fields, hypothesis_labels):
-            hypothesis = Hypothesis(hypothesis_text=self.sample_dict[field], hypothesis_label=label)
+            hypothesis = Hypothesis(hypothesis_text = self.sample_dict[field], hypothesis_label = label)
             self.hypotheses[field] = hypothesis
 
     def compute(self, estimate_audio_metrics: bool = False):
+        """
+        Computes metrics for the sample, including word frequencies and audio metrics if specified.
+
+        Parameters:
+        - estimate_audio_metrics (bool): Flag indicating whether to estimate audio metrics (default is False).
+        """
+
         self.num_chars = len(self.reference_text)
         self.words = self.reference_text.split()
         self.num_words = len(self.words)
         self.charset = set(self.reference_text)
         self.words_frequencies = dict(Counter(self.words))
-
+
         if self.duration is not None:
             self.char_rate = round(self.num_chars / self.duration, 2)
             self.word_rate = round(self.num_chars / self.duration, 2)
-
+
         if len(self.hypotheses) != 0:
             for label in self.hypotheses:
-                self.hypotheses[label].compute(
-                    reference_text=self.reference_text,
-                    reference_words=self.words,
-                    reference_num_words=self.num_words,
-                    reference_num_chars=self.num_chars,
-                )
-
+                self.hypotheses[label].compute(reference_text = self.reference_text, reference_words = self.words,
+                                               reference_num_words = self.num_words, reference_num_chars = self.num_chars)
+
         if estimate_audio_metrics and self.audio_filepath is not None:
-
+
             def eval_signal_frequency_bandwidth(self, signal, sampling_rate, threshold=-50) -> float:
                 time_stride = 0.01
                 hop_length = int(sampling_rate * time_stride)
                 n_fft = 512
                 spectrogram = np.mean(
-                    np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2,
-                    axis=1,
+                    np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2, axis=1
                 )
                 power_spectrum = librosa.power_to_db(S=spectrogram, ref=np.max, top_db=100)
                 frequency_bandwidth = 0
                 for idx in range(len(power_spectrum) - 1, -1, -1):
                     if power_spectrum[idx] > threshold:
                         frequency_bandwidth = idx / n_fft * sampling_rate
                         break
-
+
                 return frequency_bandwidth
-
+
             self.signal, self.sampling_rate = librosa.load(path=self.audio_filepath, sr=None)
-            self.frequency_bandwidth = eval_signal_frequency_bandwidth(
-                signal=self.signal, sampling_rate=self.sampling_rate
-            )
+            self.frequency_bandwidth = eval_signal_frequency_bandwidth(signal=self.signal, sampling_rate=self.sampling_rate)
             self.level_db = 20 * np.log10(np.max(np.abs(self.signal)))
 
         self.add_table_metrics_to_dict()
-
+
     def add_table_metrics_to_dict(self):
+        """
+        Adds computed metrics to the sample dictionary.
+        """
+
         metrics = {
             "num_chars": self.num_chars,
             "num_words": self.num_words,
         }
-
+
         if self.duration is not None:
             metrics["char_rate"] = self.char_rate
             metrics["word_rate"] = self.word_rate
-
+
         if len(self.hypotheses) != 0:
             for label in self.hypotheses:
                 hypothesis_metrics = self.hypotheses[label].get_table_metrics()
@@ -117,16 +154,47 @@ def add_table_metrics_to_dict(self):
         if self.frequency_bandwidth is not None:
             metrics["freq_bandwidth"] = self.frequency_bandwidth
             metrics["level_db"] = self.level_db
-
+
         self.sample_dict.update(metrics)
 
 
 class Hypothesis:
+    """
+    A class representing a hypothesis for evaluating speech-related data.
+
+    Parameters:
+    - hypothesis_text (str): The text of the hypothesis.
+    - hypothesis_label (str): Label associated with the hypothesis (default is None).
+
+    Attributes:
+    - hypothesis_text (str): The text of the hypothesis.
+    - hypothesis_label (str): Label associated with the hypothesis.
+    - hypothesis_words (list): List of words in the hypothesis text.
+    - wer (float): Word Error Rate metric.
+    - wmr (float): Word Match Rate metric.
+    - num_insertions (int): Number of insertions in the hypothesis.
+    - num_deletions (int): Number of deletions in the hypothesis.
+    - deletions_insertions_diff (int): Difference between deletions and insertions.
+    - word_match (int): Number of word matches in the hypothesis.
+    - word_distance (int): Total word distance in the hypothesis.
+    - match_words_frequencies (dict): Dictionary containing frequencies of matching words.
+    - char_distance (int): Total character distance in the hypothesis.
+    - cer (float): Character Error Rate metric.
+
+    Methods:
+    - compute(reference_text: str, reference_words: list[str], reference_num_words: int, reference_num_chars: int):
+      Computes metrics for the hypothesis based on a reference text.
+
+    - get_table_metrics() -> dict:
+      Returns a dictionary containing computed metrics suitable for tabular presentation.
+
+    """
+
     def __init__(self, hypothesis_text: str, hypothesis_label: str = None):
         self.hypothesis_text = hypothesis_text
         self.hypothesis_label = hypothesis_label
         self.hypothesis_words = None
-
+
         self.wer = None
         self.wmr = None
         self.num_insertions = None
@@ -135,28 +203,32 @@ def __init__(self, hypothesis_text: str, hypothesis_label: str = None):
         self.word_match = None
         self.word_distance = None
         self.match_words_frequencies = dict()
-
+
         self.char_distance = None
         self.cer = None
-
-    def compute(
-        self,
-        reference_text: str,
-        reference_words: list[str] = None,
-        reference_num_words: int = None,
-        reference_num_chars: int = None,
-    ):
-
+
+    def compute(self, reference_text: str, reference_words: list[str] = None,
+                reference_num_words: int = None, reference_num_chars: int = None):
+        """
+        Computes metrics for the hypothesis based on a reference text.
+
+        Parameters:
+        - reference_text (str): The reference text for comparison.
+        - reference_words (list[str]): List of words in the reference text (default is None).
+        - reference_num_words (int): Number of words in the reference text (default is None).
+        - reference_num_chars (int): Number of characters in the reference text (default is None).
+        """
+
         if reference_words is None:
             reference_words = reference_text.split()
         if reference_num_words is None:
            reference_num_words = len(reference_words)
        if reference_num_chars is None:
            reference_num_chars = len(reference_text)
-
+
        self.hypothesis_words = self.hypothesis_text.split()
-
-        # word match metrics
+
+        #word match metrics
        measures = jiwer.compute_measures(reference_text, self.hypothesis_text)
 
        self.wer = round(measures['wer'] * 100.0, 2)
@@ -166,34 +238,35 @@ def compute(
        self.deletions_insertions_diff = self.num_deletions - self.num_insertions
        self.word_match = measures['hits']
        self.word_distance = measures['substitutions'] + measures['insertions'] + measures['deletions']
-
+
        sm = SequenceMatcher()
        sm.set_seqs(reference_words, self.hypothesis_words)
-        self.match_words_frequencies = dict(
-            Counter(
-                [
-                    reference_words[word_idx]
-                    for match in sm.get_matching_blocks()
-                    for word_idx in range(match[0], match[0] + match[2])
-                ]
-            )
-        )
-
-        # char match metrics
+        self.match_words_frequencies = dict(Counter([reference_words[word_idx]
+                                                     for match in sm.get_matching_blocks()
+                                                     for word_idx in range(match[0], match[0] + match[2])]))
+
+        #char match metrics
        self.char_distance = editdistance.eval(reference_text, self.hypothesis_text)
        self.cer = round(self.char_distance / reference_num_chars * 100.0, 2)
-
+
     def get_table_metrics(self):
+        """
+        Returns a dictionary containing computed metrics.
+
+        Returns:
+        - dict: A dictionary containing computed metrics.
+        """
+
        postfix = ""
        if self.hypothesis_label != "":
            postfix = f"_{self.hypothesis_label}"
-
+
        metrics = {
-            f"WER{postfix}": self.wer,
-            f"CER{postfix}": self.cer,
-            f"WMR{postfix}": self.wmr,
-            f"I{postfix}": self.num_insertions,
-            f"D{postfix}": self.num_deletions,
-            f"D-I{postfix}": self.deletions_insertions_diff,
+            f"WER{postfix}" : self.wer,
+            f"CER{postfix}" : self.cer,
+            f"WMR{postfix}" : self.wmr,
+            f"I{postfix}" : self.num_insertions,
+            f"D{postfix}" : self.num_deletions,
+            f"D-I{postfix}" : self.deletions_insertions_diff
        }
        return metrics
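
Usage note: the classes documented by this commit could be driven roughly as below. This is a minimal sketch, not part of the commit; the import path sde.dataloader.sample is assumed from the file tree above (the exact module and file name are not visible in this diff), and the manifest line uses illustrative values in the usual NeMo JSON-lines layout with "text" and "pred_text" fields.

import json

from sde.dataloader.sample import Sample  # assumed import path; adjust to the actual module

# One manifest entry in JSON-lines form (illustrative values).
manifest_line = json.dumps({
    "audio_filepath": "audio/utt1.wav",
    "duration": 2.5,
    "text": "hello world",
    "pred_text": "hello word",
})

sample = Sample()
# Parse the reference text, duration, and the "pred_text" hypothesis from the line.
sample.parse_line(manifest_line, reference_field="text", hypothesis_fields=["pred_text"])
# Compute text metrics; estimate_audio_metrics=True would additionally load the audio
# with librosa and estimate frequency bandwidth and signal level.
sample.compute(estimate_audio_metrics=False)

# Per-sample metrics (num_chars, num_words, char_rate, WER_1, CER_1, ...) are merged
# back into the original manifest dictionary.
print(sample.sample_dict)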
