From 061723e2b8151ae3de4e504d5c7db825944eca51 Mon Sep 17 00:00:00 2001 From: Arto Date: Sun, 30 Jan 2022 18:56:01 +0000 Subject: [PATCH 1/2] add scores to Wav2Vec2WithLMOutput --- .../processing_wav2vec2_with_lm.py | 24 ++++++++++++----- tests/test_processor_wav2vec2_with_lm.py | 27 ++++++++++++------- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index 0c8ac8e09864..dd7e5dff4cf4 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput): Args: text (list of `str`): Decoded logits in text from. Usually the speech transcription. + logit_score (list of `float`): + Total logit score of the beam associated with produced text. + lm_score (list of `float`): + Fused lm_score of the beam associated with produced text. """ text: Union[List[str], str] + logit_score: Union[List[float], float] = None + lm_score: Union[List[float], float] = None class Wav2Vec2ProcessorWithLM: @@ -299,7 +305,8 @@ def batch_decode( hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT # create multiprocessing pool and list numpy arrays - logits_list = [array for array in logits] + # filter out logits padding + logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits] pool = get_context("fork").Pool(num_processes) # pyctcdecode @@ -316,11 +323,14 @@ def batch_decode( # clone multi-processing pool pool.close() - # extract text - batch_texts = [d[0][0] for d in decoded_beams] - + # extract text and scores + batch_texts, logit_scores, lm_scores = [], [], [] + for d in decoded_beams: + batch_texts.append(d[0][0]) + logit_scores.append(d[0][-2]) + lm_scores.append(d[0][-1]) # more output features will be added in the future - return Wav2Vec2DecoderWithLMOutput(text=batch_texts) + return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores) def decode( self, @@ -378,7 +388,9 @@ def decode( ) # more output features will be added in the future - return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0]) + return Wav2Vec2DecoderWithLMOutput( + text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1] + ) @contextmanager def as_target_processor(self): diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py index 37c6ff01d9c5..9e25236a6cc1 100644 --- a/tests/test_processor_wav2vec2_with_lm.py +++ b/tests/test_processor_wav2vec2_with_lm.py @@ -177,12 +177,14 @@ def test_decoder(self): logits = self._get_dummy_logits(shape=(10, 16), seed=13) - decoded_processor = processor.decode(logits).text + decoded_processor = processor.decode(logits) - decoded_decoder = decoder.decode_beams(logits)[0][0] + decoded_decoder = decoder.decode_beams(logits)[0] - self.assertEqual(decoded_decoder, decoded_processor) - self.assertEqual(" ", decoded_processor) + self.assertEqual(decoded_decoder[0], decoded_processor.text) + self.assertEqual(" ", decoded_processor.text) + self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score) + self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score) def test_decoder_batch(self): feature_extractor = self.get_feature_extractor() @@ -193,13 +195,20 @@ def test_decoder_batch(self): logits = self._get_dummy_logits() - decoded_processor = processor.batch_decode(logits).text + decoded_processor = processor.batch_decode(logits) logits_list = [array for array in logits] - decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(Pool(), logits_list)] - - self.assertListEqual(decoded_decoder, decoded_processor) - self.assertListEqual([" ", " "], decoded_processor) + decoded_beams = decoder.decode_beams_batch(Pool(), logits_list) + texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], [] + for beams in decoded_beams: + texts_decoder.append(beams[0][0]) + logit_scores_decoder.append(beams[0][-2]) + lm_scores_decoder.append(beams[0][-1]) + + self.assertListEqual(texts_decoder, decoded_processor.text) + self.assertListEqual([" ", " "], decoded_processor.text) + self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score) + self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score) def test_decoder_with_params(self): feature_extractor = self.get_feature_extractor() From 71e1375be7eb18553df438e340acc268a52cedfa Mon Sep 17 00:00:00 2001 From: Arto Date: Sat, 12 Feb 2022 16:48:16 +0200 Subject: [PATCH 2/2] style fixup --- tests/test_processor_wav2vec2_with_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py index 24076a85b9c2..800e4cbdc735 100644 --- a/tests/test_processor_wav2vec2_with_lm.py +++ b/tests/test_processor_wav2vec2_with_lm.py @@ -206,7 +206,7 @@ def test_decoder_batch(self): logit_scores_decoder.append(beams[0][-2]) lm_scores_decoder.append(beams[0][-1]) pool.close() - + self.assertListEqual(texts_decoder, decoded_processor.text) self.assertListEqual([" ", " "], decoded_processor.text) self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)