-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
rnnt_wer_bpe.py
430 lines (342 loc) · 21.5 KB
/
rnnt_wer_bpe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
import editdistance
import torch
from torchmetrics import Metric
from nemo.collections.asr.metrics.rnnt_wer import AbstractRNNTDecoding, RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import move_dimension_to_the_front
from nemo.collections.asr.parts.submodules import rnnt_beam_decoding
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.utils import logging
__all__ = ['RNNTBPEDecoding', 'RNNTBPEWER']
class RNNTBPEDecoding(AbstractRNNTDecoding):
    """
    Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state.

    Args:
        decoding_cfg: A dict-like object which contains the following key-value pairs.
            strategy: str value which represents the type of decoding that can occur.
                Possible values are :
                -   greedy, greedy_batch (for greedy decoding).
                -   beam, tsd, alsd (for beam search decoding).

            compute_hypothesis_token_set: A bool flag, which determines whether to compute a list of decoded
                tokens as well as the decoded string. Default is False in order to avoid double decoding
                unless required.

            preserve_alignments: Bool flag which preserves the history of logprobs generated during
                decoding (sample / batched). When set to true, the Hypothesis will contain
                the non-null value for `alignments` in it. Here, `alignments` is a List of List of
                Tuple(Tensor (of length V + 1), Tensor(scalar, label after argmax)).

                In order to obtain this hypothesis, please utilize `rnnt_decoder_predictions_tensor` function
                with the `return_hypotheses` flag set to True.

                The length of the list corresponds to the Acoustic Length (T).
                Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary.
                U is the number of target tokens for the current timestep Ti.

            compute_timestamps: A bool flag, which determines whether to compute the character/subword, or
                word based timestamp mapping the output log-probabilities to discrete intervals of timestamps.
                The timestamps will be available in the returned Hypothesis.timestep as a dictionary.

            compute_langs: a bool flag, which allows to compute language id (LID) information per token,
                word, and the entire sample (most likely language id). The LIDS will be available
                in the returned Hypothesis object as a dictionary.

            rnnt_timestamp_type: A str value, which represents the types of timestamps that should be calculated.
                Can take the following values - "char" for character/subword time stamps, "word" for word level
                time stamps and "all" (default), for both character level and word level time stamps.

            word_seperator: Str token representing the separator between words.

            preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores
                generated during decoding (sample / batched). When set to true, the Hypothesis will contain
                the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of List
                of floats.

            confidence_cfg: A dict-like object which contains the following key-value pairs related to confidence
                scores. In order to obtain hypotheses with confidence scores, please utilize
                `rnnt_decoder_predictions_tensor` function with the `preserve_frame_confidence` flag set to True.

                preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores
                    generated during decoding (sample / batched). When set to true, the Hypothesis will contain
                    the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of List
                    of floats.

                    The length of the list corresponds to the Acoustic Length (T).
                    Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores.
                    U is the number of target tokens for the current timestep Ti.
                preserve_token_confidence: Bool flag which preserves the history of per-token confidence scores
                    generated during greedy decoding (sample / batched). When set to true, the Hypothesis will
                    contain the non-null value for `token_confidence` in it. Here, `token_confidence` is a List
                    of floats.

                    The length of the list corresponds to the number of recognized tokens.
                preserve_word_confidence: Bool flag which preserves the history of per-word confidence scores
                    generated during greedy decoding (sample / batched). When set to true, the Hypothesis will
                    contain the non-null value for `word_confidence` in it. Here, `word_confidence` is a List
                    of floats.

                    The length of the list corresponds to the number of recognized words.
                exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded
                    from the `token_confidence`.
                aggregation: Which aggregation type to use for collapsing per-token confidence into per-word
                    confidence. Valid options are `mean`, `min`, `max`, `prod`.
                method_cfg: A dict-like object which contains the method name and settings to compute per-frame
                    confidence scores.

                    name: The method name (str).
                        Supported values:
                            - 'max_prob' for using the maximum token probability as a confidence.
                            - 'entropy' for using a normalized entropy of a log-likelihood vector.

                    entropy_type: Which type of entropy to use (str).
                        Used if confidence_method_cfg.name is set to `entropy`.
                        Supported values:
                            - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
                                the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
                                Note that for this entropy, the alpha should comply the following inequality:
                                (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
                                where V is the model vocabulary size.
                            - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
                                Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
                                where α is a parameter. When α == 1, it works like the Gibbs entropy.
                                More: https://en.wikipedia.org/wiki/Tsallis_entropy
                            - 'renyi' for the Rényi entropy.
                                Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
                                where α is a parameter. When α == 1, it works like the Gibbs entropy.
                                More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy

                    alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
                        When the alpha equals one, scaling is not applied to 'max_prob',
                        and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))

                    entropy_norm: A mapping of the entropy value to the interval [0,1].
                        Supported values:
                            - 'lin' for using the linear mapping.
                            - 'exp' for using exponential mapping with linear shift.

            The config may further contain the following sub-dictionaries:
            "greedy":
                max_symbols: int, describing the maximum number of target tokens to decode per
                    timestep during greedy decoding. Setting to larger values allows longer sentences
                    to be decoded, at the cost of increased execution time.
                preserve_frame_confidence: Same as above, overrides above value.
                confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg.

            "beam":
                beam_size: int, defining the beam size for beam search. Must be >= 1.
                    If beam_size == 1, will perform cached greedy search. This might give slightly different
                    results compared to the greedy search above.

                score_norm: optional bool, whether to normalize the returned beam score in the hypotheses.
                    Set to True by default.

                return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the
                    hypotheses after beam search has concluded.

                tsd_max_sym_exp: optional int, determines number of symmetric expansions of the target symbols
                    per timestep of the acoustic model. Larger values will allow longer sentences to be decoded,
                    at increased cost to execution time.

                alsd_max_target_len: optional int or float, determines the potential maximum target sequence
                    length. If an integer is provided, it can decode sequences of that particular maximum length.
                    If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len),
                    where seq_len is the length of the acoustic model output (T).

                    NOTE:
                        If a float is provided, it can be greater than 1!
                        By default, a float of 2.0 is used so that a target sequence can be at most twice
                        as long as the acoustic model output length T.

                maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient,
                    and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0.

                maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to
                    keep this as 1 in order to reduce expensive beam search cost later. int >= 0.

                maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size.
                    Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0,
                    and affects the speed of inference since large values will perform large beam search in the
                    next step.

                maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the
                    expansions. The default (2.3) is selected from the paper. It performs a comparison
                    (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and
                    max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of
                    additional tokens which can be potential candidates for expansion apart from the "most likely"
                    candidate.
                    Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby
                    improving speed but hurting accuracy). Higher values will increase the number of expansions
                    (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy).
                    This is a hyper parameter to be experimentally tuned on a validation set.

                softmax_temperature: Scales the logits of the joint prior to computing log_softmax.

        decoder: The Decoder/Prediction network module.
        joint: The Joint network module.
        tokenizer: The tokenizer which will be used for decoding.
    """

    def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec):
        # For RNNT or TDT models the blank id is the tokenizer's vocabulary size.
        blank_id = tokenizer.tokenizer.vocab_size

        # Multi-blank RNNTs: the blank id is additionally shifted by the joint's extra outputs.
        if hasattr(decoding_cfg, 'model_type') and decoding_cfg.model_type == 'multiblank':
            blank_id = tokenizer.tokenizer.vocab_size + joint.num_extra_outputs

        # NOTE(review): the tokenizer is attached before the parent constructor runs, exactly as in the
        # original code; presumably the parent may need it during construction - keep this order.
        self.tokenizer = tokenizer

        super().__init__(decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id)

        # Beam search decoders are switched to subword mode for this BPE-based decoding class.
        if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):
            self.decoding.set_decoding_type('subword')

    def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]:
        """
        Implemented by subclass in order to reduce token confidence to a word-level confidence.

        **Note**: Only supports Sentencepiece based tokenizers!

        Args:
            hypothesis: Hypothesis

        Returns:
            A list of word-level confidence scores.
        """
        return self._aggregate_token_confidence_subwords_sentencepiece(
            hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence
        )

    def decode_tokens_to_str(self, tokens: List[int]) -> str:
        """
        Implemented by subclass in order to decode a token id list into a string.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A decoded string.
        """
        return self.tokenizer.ids_to_text(tokens)

    def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
        """
        Implemented by subclass in order to decode a token id list into a token list.
        A token list is the string representation of each token id.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A list of decoded tokens.
        """
        return self.tokenizer.ids_to_tokens(tokens)

    def decode_tokens_to_lang(self, tokens: List[int]) -> str:
        """
        Compute the most likely language ID (LID) string given the tokens.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A decoded LID string.
        """
        return self.tokenizer.ids_to_lang(tokens)

    def decode_ids_to_langs(self, tokens: List[int]) -> List[str]:
        """
        Decode a token id list into language ID (LID) list.

        Delegates to the tokenizer's `ids_to_text_and_langs` method.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A list of decoded LIDS.
        """
        return self.tokenizer.ids_to_text_and_langs(tokens)

    def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hypothesis, NBestHypotheses]]:
        """
        Decode a list of hypotheses via the parent implementation, optionally adding language information.

        When `self.compute_langs` is set and the tokenizer is an `AggregateTokenizer`, every returned
        hypothesis additionally receives `langs` (most likely language id of the sample) and `langs_chars`
        (per-token language information). With any other tokenizer a warning is logged and the
        hypotheses are returned unchanged.

        Args:
            hypotheses_list: List of Hypothesis.

        Returns:
            The list of decoded hypotheses produced by the parent class, possibly enriched with
            language information.
        """
        hypotheses = super().decode_hypothesis(hypotheses_list)

        if not self.compute_langs:
            return hypotheses

        if not isinstance(self.tokenizer, AggregateTokenizer):
            logging.warning(
                "Ignoring request for lang output in hypotheses since the model does not use an aggregate tokenizer"
            )
            return hypotheses

        for ind, raw_hypothesis in enumerate(hypotheses_list):
            # Extract the integer encoded hypothesis
            prediction = raw_hypothesis.y_sequence
            if not isinstance(prediction, list):
                prediction = prediction.tolist()

            # RNN-T sample level is already preprocessed by implicit RNNT decoding
            # Simply remove any blank tokens
            prediction = [p for p in prediction if p != self.blank_id]

            hypotheses[ind].langs = self.decode_tokens_to_lang(prediction)
            hypotheses[ind].langs_chars = self.decode_ids_to_langs(prediction)

        return hypotheses
class RNNTBPEWER(Metric):
    """
    This metric computes numerator and denominator for Overall Word Error Rate (WER) between prediction and
    reference texts. When doing distributed training/evaluation the accumulated numerator (`scores`) and
    denominator (`words`) states are reduced between all workers using SUM operations.

    `compute()` returns three values: (wer, wer_numerator, wer_denominator), where
    WER = wer_numerator / wer_denominator.

    If used with PytorchLightning LightningModule, include wer_numerator and wer_denominator inside
    validation_step results. Then aggregate (sum) them at the end of the validation epoch to correctly compute
    validation WER.

    Example:
        def validation_step(self, batch, batch_idx):
            ...
            self.wer.update(encoded, encoded_len, transcript, transcript_len)
            _, wer_num, wer_denom = self.wer.compute()
            self.wer.reset()
            self.val_outputs.append({'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom})

        def on_validation_epoch_end(self):
            ...
            wer_num = torch.stack([x['val_wer_num'] for x in self.val_outputs]).sum()
            wer_denom = torch.stack([x['val_wer_denom'] for x in self.val_outputs]).sum()
            tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom}
            self.val_outputs.clear()  # free memory
            return {'val_loss': val_loss_mean, 'log': tensorboard_logs}

    Args:
        decoding: RNNTBPEDecoding object that will perform autoregressive decoding of the RNNT model.
        batch_dim_index: Index of the batch dimension.
        use_cer: Whether to use Character Error Rate instead of Word Error Rate.
        log_prediction: Whether to log a single decoded sample per call.
        dist_sync_on_step: Forwarded unchanged to the `Metric` base class constructor.

    Returns:
        res: a tuple of 3 zero dimensional ``torch.Tensor`` objects: a WER score, a sum of Levenshtein's
            distances for all prediction - reference pairs, total number of words in all references.
    """

    full_state_update = True

    def __init__(
        self,
        decoding: RNNTBPEDecoding,
        batch_dim_index=0,
        use_cer: bool = False,
        log_prediction: bool = True,
        dist_sync_on_step=False,
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.decoding = decoding
        self.batch_dim_index = batch_dim_index
        self.use_cer = use_cer
        self.log_prediction = log_prediction
        self.blank_id = self.decoding.blank_id
        self.tokenizer = self.decoding.tokenizer

        # Running sums of edit distances and reference lengths; summed across workers on sync.
        self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
        self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)

    def update(
        self,
        encoder_output: torch.Tensor,
        encoded_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> None:
        """
        Decode one batch and accumulate its edit distance and reference length into the metric state.

        Args:
            encoder_output: Encoder output passed to `decoding.rnnt_decoder_predictions_tensor`.
            encoded_lengths: Lengths of the encoder outputs, passed alongside `encoder_output`.
            targets: Reference token ids; the batch dimension is located at `batch_dim_index`.
            target_lengths: Number of valid token ids in each row of `targets`.
        """
        with torch.no_grad():
            targets_cpu_tensor = move_dimension_to_the_front(targets.long().cpu(), self.batch_dim_index)
            tgt_lengths_cpu_tensor = target_lengths.long().cpu()

            # Turn every (unpadded) reference token sequence of the batch back into text.
            references = []
            for ind in range(targets_cpu_tensor.shape[0]):
                tgt_len = tgt_lengths_cpu_tensor[ind].item()
                target = targets_cpu_tensor[ind][:tgt_len].tolist()
                references.append(self.decoding.decode_tokens_to_str(target))

            hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths)

        # Only log when there is a sample to show; an empty batch must not raise an IndexError.
        if self.log_prediction and len(references) > 0 and len(hypotheses) > 0:
            logging.info("\n")
            logging.info(f"reference :{references[0]}")
            logging.info(f"predicted :{hypotheses[0]}")

        words = 0
        scores = 0
        for h, r in zip(hypotheses, references):
            if self.use_cer:
                # Character level: compare the texts symbol by symbol.
                h_list = list(h)
                r_list = list(r)
            else:
                # Word level: compare whitespace separated tokens.
                h_list = h.split()
                r_list = r.split()
            words += len(r_list)
            # Compute Levenshtein's distance
            scores += editdistance.eval(h_list, r_list)

        del hypotheses

        self.scores += torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype)
        self.words += torch.tensor(words, device=self.words.device, dtype=self.words.dtype)

    def compute(self):
        """
        Return the accumulated error rate together with its numerator and denominator.

        Returns:
            A tuple (wer, scores, words): the error rate `scores / words`, the summed edit distances and
            the summed reference lengths. If no reference words were accumulated, the division is by
            zero and the returned rate is not a finite number.
        """
        wer = self.scores.float() / self.words
        return wer, self.scores.detach(), self.words.detach()
@dataclass
class RNNTBPEDecodingConfig(RNNTDecodingConfig):
    """Decoding configuration for BPE-based RNNT models; adds no fields to `RNNTDecodingConfig`."""