diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index c31b209c1879..2dc88564bf86 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput): Args: text (list of `str`): Decoded logits in text from. Usually the speech transcription. + logit_score (list of `float`): + Total logit score of the beam associated with produced text. + lm_score (list of `float`): + Fused lm_score of the beam associated with produced text. """ text: Union[List[str], str] + logit_score: Union[List[float], float] = None + lm_score: Union[List[float], float] = None class Wav2Vec2ProcessorWithLM(ProcessorMixin): @@ -283,7 +289,8 @@ def batch_decode( ) # create multiprocessing pool and list numpy arrays - logits_list = [array for array in logits] + # filter out logits padding + logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits] pool = get_context("fork").Pool(num_processes) # pyctcdecode @@ -300,11 +307,14 @@ def batch_decode( # clone multi-processing pool pool.close() - # extract text - batch_texts = [d[0][0] for d in decoded_beams] - + # extract text and scores + batch_texts, logit_scores, lm_scores = [], [], [] + for d in decoded_beams: + batch_texts.append(d[0][0]) + logit_scores.append(d[0][-2]) + lm_scores.append(d[0][-1]) # more output features will be added in the future - return Wav2Vec2DecoderWithLMOutput(text=batch_texts) + return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores) def decode( self, @@ -379,7 +389,9 @@ def decode( ) # more output features will be added in the future - return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0]) + return Wav2Vec2DecoderWithLMOutput( + text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1] + ) @contextmanager def as_target_processor(self): diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py index f918a0894a47..800e4cbdc735 100644 --- a/tests/test_processor_wav2vec2_with_lm.py +++ b/tests/test_processor_wav2vec2_with_lm.py @@ -177,12 +177,14 @@ def test_decoder(self): logits = self._get_dummy_logits(shape=(10, 16), seed=13) - decoded_processor = processor.decode(logits).text + decoded_processor = processor.decode(logits) - decoded_decoder = decoder.decode_beams(logits)[0][0] + decoded_decoder = decoder.decode_beams(logits)[0] - self.assertEqual(decoded_decoder, decoded_processor) - self.assertEqual(" ", decoded_processor) + self.assertEqual(decoded_decoder[0], decoded_processor.text) + self.assertEqual(" ", decoded_processor.text) + self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score) + self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score) def test_decoder_batch(self): feature_extractor = self.get_feature_extractor() @@ -193,15 +195,22 @@ def test_decoder_batch(self): logits = self._get_dummy_logits() - decoded_processor = processor.batch_decode(logits).text + decoded_processor = processor.batch_decode(logits) logits_list = [array for array in logits] pool = get_context("fork").Pool() - decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)] + decoded_beams = decoder.decode_beams_batch(pool, logits_list) + texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], [] + for beams in decoded_beams: + texts_decoder.append(beams[0][0]) + logit_scores_decoder.append(beams[0][-2]) + lm_scores_decoder.append(beams[0][-1]) pool.close() - self.assertListEqual(decoded_decoder, decoded_processor) - self.assertListEqual([" ", " "], decoded_processor) + self.assertListEqual(texts_decoder, decoded_processor.text) + self.assertListEqual([" ", " "], decoded_processor.text) + self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score) + self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score) def test_decoder_with_params(self): feature_extractor = self.get_feature_extractor()