1
1
import json
2
2
from collections import Counter
3
+ import jiwer
3
4
from difflib import SequenceMatcher
4
-
5
5
import editdistance
6
- import jiwer
7
- import librosa
8
6
import numpy as np
7
+ import librosa
9
8
10
9
11
10
class Sample :
11
+ """
12
+ A class representing a sample of data, including reference and hypothesis texts, for processing and analysis.
13
+
14
+ Attributes:
15
+ - reference_text (str): The reference text associated with the sample.
16
+ - num_chars (int): Number of characters in the reference text.
17
+ - charset (set): Set of unique characters in the reference text.
18
+ - words (list): List of words in the reference text.
19
+ - num_words (int): Number of words in the reference text.
20
+ - words_frequencies (dict): Dictionary containing word frequencies in the reference text.
21
+ - duration (float): Duration of the audio in the sample.
22
+ - frequency_bandwidth (float): Frequency bandwidth of the audio signal (computed if audio file provided).
23
+ - level_db (float): Level of the audio signal in decibels (computed if audio file provided).
24
+ - hypotheses (dict): Dictionary containing hypothesis objects for different fields.
25
+
26
+ Methods:
27
+ - reset():
28
+ Resets the sample attributes to their initial state.
29
+
30
+ - parse_line(manifest_line: str, reference_field: str = "text",
31
+ hypothesis_fields: list[str] = ["pred_text"],
32
+ hypothesis_labels: list[str] = None):
33
+ Parses a line from the manifest file and updates the sample information.
34
+
35
+ - compute(estimate_audio_metrics: bool = False):
36
+ Computes metrics for the sample, including word frequencies and audio metrics if specified.
37
+
38
+ - add_table_metrics_to_dict():
39
+ Adds computed metrics to the sample dictionary.
40
+ """
41
+
12
42
def __init__ (self ):
13
43
self .reference_text = None
14
44
self .num_chars = None
@@ -22,6 +52,10 @@ def __init__(self):
22
52
self .hypotheses = {}
23
53
24
54
def reset (self ):
55
+ """
56
+ Resets the sample attributes to their initial state.
57
+ """
58
+
25
59
self .reference_text = None
26
60
self .num_chars = None
27
61
self .charset = set ()
@@ -32,83 +66,86 @@ def reset(self):
32
66
self .frequency_bandwidth = None
33
67
self .level_db = None
34
68
self .hypotheses = {}
35
-
36
- def parse_line (
37
- self ,
38
- manifest_line : str ,
39
- reference_field : str = "text" ,
40
- hypothesis_fields : list [str ] = ["pred_text" ],
41
- hypothesis_labels : list [str ] = None ,
42
- ):
43
-
69
+
70
+ def parse_line (self , manifest_line : str , reference_field : str = "text" ,
71
+ hypothesis_fields : list [str ] = ["pred_text" ],
72
+ hypothesis_labels : list [str ] = None ):
73
+ """
74
+ Parses a line from the manifest file and updates the sample information.
75
+ """
76
+
44
77
self .sample_dict = json .loads (manifest_line )
45
78
self .reference_text = self .sample_dict .get (reference_field , None )
46
79
self .duration = self .sample_dict .get ("duration" , None )
47
-
80
+
48
81
if hypothesis_labels is None :
49
82
hypothesis_labels = list (range (1 , len (hypothesis_fields ) + 1 ))
50
-
83
+
51
84
for field , label in zip (hypothesis_fields , hypothesis_labels ):
52
- hypothesis = Hypothesis (hypothesis_text = self .sample_dict [field ], hypothesis_label = label )
85
+ hypothesis = Hypothesis (hypothesis_text = self .sample_dict [field ], hypothesis_label = label )
53
86
self .hypotheses [field ] = hypothesis
54
87
55
88
def compute (self , estimate_audio_metrics : bool = False ):
89
+ """
90
+ Computes metrics for the sample, including word frequencies and audio metrics if specified.
91
+
92
+ Parameters:
93
+ - estimate_audio_metrics (bool): Flag indicating whether to estimate audio metrics (default is False).
94
+ """
95
+
56
96
self .num_chars = len (self .reference_text )
57
97
self .words = self .reference_text .split ()
58
98
self .num_words = len (self .words )
59
99
self .charset = set (self .reference_text )
60
100
self .words_frequencies = dict (Counter (self .words ))
61
-
101
+
62
102
if self .duration is not None :
63
103
self .char_rate = round (self .num_chars / self .duration , 2 )
64
104
self .word_rate = round (self .num_chars / self .duration , 2 )
65
-
105
+
66
106
if len (self .hypotheses ) != 0 :
67
107
for label in self .hypotheses :
68
- self .hypotheses [label ].compute (
69
- reference_text = self .reference_text ,
70
- reference_words = self .words ,
71
- reference_num_words = self .num_words ,
72
- reference_num_chars = self .num_chars ,
73
- )
74
-
108
+ self .hypotheses [label ].compute (reference_text = self .reference_text , reference_words = self .words ,
109
+ reference_num_words = self .num_words , reference_num_chars = self .num_chars )
110
+
75
111
if estimate_audio_metrics and self .audio_filepath is not None :
76
-
112
+
77
113
def eval_signal_frequency_bandwidth (self , signal , sampling_rate , threshold = - 50 ) -> float :
78
114
time_stride = 0.01
79
115
hop_length = int (sampling_rate * time_stride )
80
116
n_fft = 512
81
117
spectrogram = np .mean (
82
- np .abs (librosa .stft (y = signal , n_fft = n_fft , hop_length = hop_length , window = 'blackmanharris' )) ** 2 ,
83
- axis = 1 ,
118
+ np .abs (librosa .stft (y = signal , n_fft = n_fft , hop_length = hop_length , window = 'blackmanharris' )) ** 2 , axis = 1
84
119
)
85
120
power_spectrum = librosa .power_to_db (S = spectrogram , ref = np .max , top_db = 100 )
86
121
frequency_bandwidth = 0
87
122
for idx in range (len (power_spectrum ) - 1 , - 1 , - 1 ):
88
123
if power_spectrum [idx ] > threshold :
89
124
frequency_bandwidth = idx / n_fft * sampling_rate
90
125
break
91
-
126
+
92
127
return frequency_bandwidth
93
-
128
+
94
129
self .signal , self .sampling_rate = librosa .load (path = self .audio_filepath , sr = None )
95
- self .frequency_bandwidth = eval_signal_frequency_bandwidth (
96
- signal = self .signal , sampling_rate = self .sampling_rate
97
- )
130
+ self .frequency_bandwidth = eval_signal_frequency_bandwidth (signal = self .signal , sampling_rate = self .sampling_rate )
98
131
self .level_db = 20 * np .log10 (np .max (np .abs (self .signal )))
99
132
100
133
self .add_table_metrics_to_dict ()
101
-
134
+
102
135
def add_table_metrics_to_dict (self ):
136
+ """
137
+ Adds computed metrics to the sample dictionary.
138
+ """
139
+
103
140
metrics = {
104
141
"num_chars" : self .num_chars ,
105
142
"num_words" : self .num_words ,
106
143
}
107
-
144
+
108
145
if self .duration is not None :
109
146
metrics ["char_rate" ] = self .char_rate
110
147
metrics ["word_rate" ] = self .word_rate
111
-
148
+
112
149
if len (self .hypotheses ) != 0 :
113
150
for label in self .hypotheses :
114
151
hypothesis_metrics = self .hypotheses [label ].get_table_metrics ()
@@ -117,16 +154,47 @@ def add_table_metrics_to_dict(self):
117
154
if self .frequency_bandwidth is not None :
118
155
metrics ["freq_bandwidth" ] = self .frequency_bandwidth
119
156
metrics ["level_db" ] = self .level_db
120
-
157
+
121
158
self .sample_dict .update (metrics )
122
159
123
160
124
161
class Hypothesis :
162
+ """
163
+ A class representing a hypothesis for evaluating speech-related data.
164
+
165
+ Parameters:
166
+ - hypothesis_text (str): The text of the hypothesis.
167
+ - hypothesis_label (str): Label associated with the hypothesis (default is None).
168
+
169
+ Attributes:
170
+ - hypothesis_text (str): The text of the hypothesis.
171
+ - hypothesis_label (str): Label associated with the hypothesis.
172
+ - hypothesis_words (list): List of words in the hypothesis text.
173
+ - wer (float): Word Error Rate metric.
174
+ - wmr (float): Word Match Rate metric.
175
+ - num_insertions (int): Number of insertions in the hypothesis.
176
+ - num_deletions (int): Number of deletions in the hypothesis.
177
+ - deletions_insertions_diff (int): Difference between deletions and insertions.
178
+ - word_match (int): Number of word matches in the hypothesis.
179
+ - word_distance (int): Total word distance in the hypothesis.
180
+ - match_words_frequencies (dict): Dictionary containing frequencies of matching words.
181
+ - char_distance (int): Total character distance in the hypothesis.
182
+ - cer (float): Character Error Rate metric.
183
+
184
+ Methods:
185
+ - compute(reference_text: str, reference_words: list[str], reference_num_words: int, reference_num_chars: int):
186
+ Computes metrics for the hypothesis based on a reference text.
187
+
188
+ - get_table_metrics() -> dict:
189
+ Returns a dictionary containing computed metrics suitable for tabular presentation.
190
+
191
+ """
192
+
125
193
def __init__ (self , hypothesis_text : str , hypothesis_label : str = None ):
126
194
self .hypothesis_text = hypothesis_text
127
195
self .hypothesis_label = hypothesis_label
128
196
self .hypothesis_words = None
129
-
197
+
130
198
self .wer = None
131
199
self .wmr = None
132
200
self .num_insertions = None
@@ -135,28 +203,32 @@ def __init__(self, hypothesis_text: str, hypothesis_label: str = None):
135
203
self .word_match = None
136
204
self .word_distance = None
137
205
self .match_words_frequencies = dict ()
138
-
206
+
139
207
self .char_distance = None
140
208
self .cer = None
141
-
142
- def compute (
143
- self ,
144
- reference_text : str ,
145
- reference_words : list [str ] = None ,
146
- reference_num_words : int = None ,
147
- reference_num_chars : int = None ,
148
- ):
149
-
209
+
210
+ def compute (self , reference_text : str , reference_words : list [str ] = None ,
211
+ reference_num_words : int = None , reference_num_chars : int = None ):
212
+ """
213
+ Computes metrics for the hypothesis based on a reference text.
214
+
215
+ Parameters:
216
+ - reference_text (str): The reference text for comparison.
217
+ - reference_words (list[str]): List of words in the reference text (default is None).
218
+ - reference_num_words (int): Number of words in the reference text (default is None).
219
+ - reference_num_chars (int): Number of characters in the reference text (default is None).
220
+ """
221
+
150
222
if reference_words is None :
151
223
reference_words = reference_text .split ()
152
224
if reference_num_words is None :
153
225
reference_num_words = len (reference_words )
154
226
if reference_num_chars is None :
155
227
reference_num_chars = len (reference_text )
156
-
228
+
157
229
self .hypothesis_words = self .hypothesis_text .split ()
158
-
159
- # word match metrics
230
+
231
+ #word match metrics
160
232
measures = jiwer .compute_measures (reference_text , self .hypothesis_text )
161
233
162
234
self .wer = round (measures ['wer' ] * 100.0 , 2 )
@@ -166,34 +238,35 @@ def compute(
166
238
self .deletions_insertions_diff = self .num_deletions - self .num_insertions
167
239
self .word_match = measures ['hits' ]
168
240
self .word_distance = measures ['substitutions' ] + measures ['insertions' ] + measures ['deletions' ]
169
-
241
+
170
242
sm = SequenceMatcher ()
171
243
sm .set_seqs (reference_words , self .hypothesis_words )
172
- self .match_words_frequencies = dict (
173
- Counter (
174
- [
175
- reference_words [word_idx ]
176
- for match in sm .get_matching_blocks ()
177
- for word_idx in range (match [0 ], match [0 ] + match [2 ])
178
- ]
179
- )
180
- )
181
-
182
- # char match metrics
244
+ self .match_words_frequencies = dict (Counter ([reference_words [word_idx ]
245
+ for match in sm .get_matching_blocks ()
246
+ for word_idx in range (match [0 ], match [0 ] + match [2 ])]))
247
+
248
+ #char match metrics
183
249
self .char_distance = editdistance .eval (reference_text , self .hypothesis_text )
184
250
self .cer = round (self .char_distance / reference_num_chars * 100.0 , 2 )
185
-
251
+
186
252
def get_table_metrics (self ):
253
+ """
254
+ Returns a dictionary containing computed metrics.
255
+
256
+ Returns:
257
+ - dict: A dictionary containing computed metrics.
258
+ """
259
+
187
260
postfix = ""
188
261
if self .hypothesis_label != "" :
189
262
postfix = f"_{ self .hypothesis_label } "
190
-
263
+
191
264
metrics = {
192
- f"WER{ postfix } " : self .wer ,
193
- f"CER{ postfix } " : self .cer ,
194
- f"WMR{ postfix } " : self .wmr ,
195
- f"I{ postfix } " : self .num_insertions ,
196
- f"D{ postfix } " : self .num_deletions ,
197
- f"D-I{ postfix } " : self .deletions_insertions_diff ,
265
+ f"WER{ postfix } " : self .wer ,
266
+ f"CER{ postfix } " : self .cer ,
267
+ f"WMR{ postfix } " : self .wmr ,
268
+ f"I{ postfix } " : self .num_insertions ,
269
+ f"D{ postfix } " : self .num_deletions ,
270
+ f"D-I{ postfix } " : self .deletions_insertions_diff
198
271
}
199
272
return metrics
0 commit comments