diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
index c31b209c1879..2dc88564bf86 100644
--- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
@@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput):
Args:
text (list of `str`):
Decoded logits in text from. Usually the speech transcription.
+ logit_score (list of `float`):
+ Total logit score of the beam associated with produced text.
+ lm_score (list of `float`):
+ Fused lm_score of the beam associated with produced text.
"""
text: Union[List[str], str]
+ logit_score: Union[List[float], float] = None
+ lm_score: Union[List[float], float] = None
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
@@ -283,7 +289,8 @@ def batch_decode(
)
# create multiprocessing pool and list numpy arrays
- logits_list = [array for array in logits]
+ # filter out logits padding
+ logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
pool = get_context("fork").Pool(num_processes)
# pyctcdecode
@@ -300,11 +307,14 @@ def batch_decode(
# clone multi-processing pool
pool.close()
- # extract text
- batch_texts = [d[0][0] for d in decoded_beams]
-
+ # extract text and scores
+ batch_texts, logit_scores, lm_scores = [], [], []
+ for d in decoded_beams:
+ batch_texts.append(d[0][0])
+ logit_scores.append(d[0][-2])
+ lm_scores.append(d[0][-1])
# more output features will be added in the future
- return Wav2Vec2DecoderWithLMOutput(text=batch_texts)
+ return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
def decode(
self,
@@ -379,7 +389,9 @@ def decode(
)
# more output features will be added in the future
- return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0])
+ return Wav2Vec2DecoderWithLMOutput(
+ text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
+ )
@contextmanager
def as_target_processor(self):
diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py
index f918a0894a47..800e4cbdc735 100644
--- a/tests/test_processor_wav2vec2_with_lm.py
+++ b/tests/test_processor_wav2vec2_with_lm.py
@@ -177,12 +177,14 @@ def test_decoder(self):
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
- decoded_processor = processor.decode(logits).text
+ decoded_processor = processor.decode(logits)
- decoded_decoder = decoder.decode_beams(logits)[0][0]
+ decoded_decoder = decoder.decode_beams(logits)[0]
- self.assertEqual(decoded_decoder, decoded_processor)
- self.assertEqual(" ", decoded_processor)
+ self.assertEqual(decoded_decoder[0], decoded_processor.text)
+ self.assertEqual(" ", decoded_processor.text)
+ self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
+ self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
def test_decoder_batch(self):
feature_extractor = self.get_feature_extractor()
@@ -193,15 +195,22 @@ def test_decoder_batch(self):
logits = self._get_dummy_logits()
- decoded_processor = processor.batch_decode(logits).text
+ decoded_processor = processor.batch_decode(logits)
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
- decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
+ decoded_beams = decoder.decode_beams_batch(pool, logits_list)
+ texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
+ for beams in decoded_beams:
+ texts_decoder.append(beams[0][0])
+ logit_scores_decoder.append(beams[0][-2])
+ lm_scores_decoder.append(beams[0][-1])
pool.close()
- self.assertListEqual(decoded_decoder, decoded_processor)
- self.assertListEqual([" ", " "], decoded_processor)
+ self.assertListEqual(texts_decoder, decoded_processor.text)
+ self.assertListEqual([" ", " "], decoded_processor.text)
+ self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
+ self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()