Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput):
Args:
text (list of `str`):
Decoded logits in text from. Usually the speech transcription.
logit_score (list of `float`):
Total logit score of the beam associated with produced text.
lm_score (list of `float`):
Fused lm_score of the beam associated with produced text.
"""

text: Union[List[str], str]
logit_score: Union[List[float], float] = None
lm_score: Union[List[float], float] = None


class Wav2Vec2ProcessorWithLM(ProcessorMixin):
Expand Down Expand Up @@ -283,7 +289,8 @@ def batch_decode(
)

# create multiprocessing pool and list numpy arrays
logits_list = [array for array in logits]
# filter out logits padding
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

pool = get_context("fork").Pool(num_processes)

# pyctcdecode
Expand All @@ -300,11 +307,14 @@ def batch_decode(
# clone multi-processing pool
pool.close()

# extract text
batch_texts = [d[0][0] for d in decoded_beams]

# extract text and scores
batch_texts, logit_scores, lm_scores = [], [], []
for d in decoded_beams:
batch_texts.append(d[0][0])
logit_scores.append(d[0][-2])
lm_scores.append(d[0][-1])
# more output features will be added in the future
return Wav2Vec2DecoderWithLMOutput(text=batch_texts)
return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)

def decode(
self,
Expand Down Expand Up @@ -379,7 +389,9 @@ def decode(
)

# more output features will be added in the future
return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0])
return Wav2Vec2DecoderWithLMOutput(
text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
)

@contextmanager
def as_target_processor(self):
Expand Down
25 changes: 17 additions & 8 deletions tests/test_processor_wav2vec2_with_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,14 @@ def test_decoder(self):

logits = self._get_dummy_logits(shape=(10, 16), seed=13)

decoded_processor = processor.decode(logits).text
decoded_processor = processor.decode(logits)

decoded_decoder = decoder.decode_beams(logits)[0][0]
decoded_decoder = decoder.decode_beams(logits)[0]

self.assertEqual(decoded_decoder, decoded_processor)
self.assertEqual("</s> <s> </s>", decoded_processor)
self.assertEqual(decoded_decoder[0], decoded_processor.text)
self.assertEqual("</s> <s> </s>", decoded_processor.text)
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)

def test_decoder_batch(self):
feature_extractor = self.get_feature_extractor()
Expand All @@ -193,15 +195,22 @@ def test_decoder_batch(self):

logits = self._get_dummy_logits()

decoded_processor = processor.batch_decode(logits).text
decoded_processor = processor.batch_decode(logits)

logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
for beams in decoded_beams:
texts_decoder.append(beams[0][0])
logit_scores_decoder.append(beams[0][-2])
lm_scores_decoder.append(beams[0][-1])
pool.close()

self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
self.assertListEqual(texts_decoder, decoded_processor.text)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)

def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()
Expand Down