Fix character beam decoding algorithm with vocab index map #6140

Merged (1 commit, Mar 7, 2023)
nemo/collections/asr/parts/submodules/ctc_beam_decoding.py (9 additions, 2 deletions)

@@ -98,6 +98,10 @@ def __init__(self, blank_id: int, beam_size: int):
         self.decoding_type = None
         self.tokenizer = None

+        # Utility maps for vocabulary
+        self.vocab_index_map = None
+        self.index_vocab_map = None
+
         # Internal variable, used to prevent double reduction of consecutive tokens (ctc collapse)
         self.override_fold_consecutive_value = None

@@ -110,6 +114,8 @@ def set_vocabulary(self, vocab: List[str]):
         Note that this vocabulary must NOT contain the "BLANK" token.
         """
         self.vocab = vocab
+        self.vocab_index_map = {v: i for i, v in enumerate(vocab)}
+        self.index_vocab_map = {i: v for i, v in enumerate(vocab)}

    def set_decoding_type(self, decoding_type: str):
        """
@@ -352,7 +358,8 @@ def default_beam_search(
                 if self.decoding_type == 'subword':
                     pred_token_ids = [ord(c) - self.token_offset for c in candidate[1]]
                 else:
-                    pred_token_ids = candidate[1]
+                    # Char models
+                    pred_token_ids = [self.vocab_index_map[c] for c in candidate[1]]

                 # We preserve the token ids and the score for this hypothesis
                 hypothesis.y_sequence = pred_token_ids
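Note: for char models, candidate[1] appears to be the decoded text itself, so the removed line stored raw characters in hypothesis.y_sequence where integer token ids were expected. A hedged before/after sketch, reusing the toy vocabulary from the sketch above:

# Sketch only: "cab" and the vocabulary are invented for illustration.
vocab_index_map = {"a": 0, "b": 1, "c": 2}
candidate_text = "cab"                # stands in for candidate[1]

buggy_y_sequence = candidate_text     # old behavior: a string, not token ids
fixed_y_sequence = [vocab_index_map[c] for c in candidate_text]

print(list(buggy_y_sequence))         # ['c', 'a', 'b']  (characters)
print(fixed_y_sequence)               # [2, 0, 1]        (integer ids)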
@@ -444,7 +451,7 @@ def _pyctcdecode_beam_search(
                     raise ValueError("Vocab must be provided for character decoding. Use set_vocab().")

                 chars = list(candidate[0])
-                pred_token_ids = [self.vocab[c] for c in chars]
+                pred_token_ids = [self.vocab_index_map[c] for c in chars]

                 hypothesis.y_sequence = pred_token_ids
                 hypothesis.text = candidate[0]  # text
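Note: the replaced line in this hunk would crash outright for character models, since self.vocab is a list and indexing a list with a character raises TypeError. A small sketch of the failure mode and the fix, again with an invented vocabulary:

# Sketch only: toy vocabulary, demonstrating why the dict lookup is needed.
vocab = ["a", "b", "c"]
chars = list("ab")

try:
    pred_token_ids = [vocab[c] for c in chars]   # old code: list indexed by str
except TypeError as e:
    print(e)   # list indices must be integers or slices, not str

vocab_index_map = {v: i for i, v in enumerate(vocab)}
pred_token_ids = [vocab_index_map[c] for c in chars]   # new code: dict lookup
print(pred_token_ids)   # [0, 1]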