fix partial transcribe (NVIDIA#7284)

* fix partial transcribe

Signed-off-by: stevehuang52 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

Signed-off-by: stevehuang52 <[email protected]>

---------

Signed-off-by: stevehuang52 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: dorotat <[email protected]>
2 people authored and dorotat-nv committed Aug 24, 2023
1 parent a0bb04e commit 769b740
Showing 2 changed files with 46 additions and 37 deletions.
10 changes: 5 additions & 5 deletions examples/asr/transcribe_speech.py
```diff
@@ -171,6 +171,9 @@ class TranscriptionConfig:
     # if True, will also skip writing anything to the output file
     return_transcriptions: bool = False
 
+    # Set to False to return text instead of hypotheses from the transcribe function, so as to save memory
+    return_hypotheses: bool = True
+
 
 @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
 def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]:
@@ -228,9 +231,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]:
     asr_model.set_trainer(trainer)
     asr_model = asr_model.eval()
 
-    # collect additional transcription information
-    return_hypotheses = True
-
     # we will adjust this flag if the model does not support it
     compute_timestamps = cfg.compute_timestamps
     compute_langs = cfg.compute_langs
@@ -327,7 +327,7 @@ def autocast():
                     path2manifest=cfg.dataset_manifest,
                     batch_size=cfg.batch_size,
                     num_workers=cfg.num_workers,
-                    return_hypotheses=return_hypotheses,
+                    return_hypotheses=cfg.return_hypotheses,
                     channel_selector=cfg.channel_selector,
                     augmentor=augmentor,
                     decoder_type=cfg.decoder_type,
@@ -337,7 +337,7 @@ def autocast():
                 paths2audio_files=filepaths,
                 batch_size=cfg.batch_size,
                 num_workers=cfg.num_workers,
-                return_hypotheses=return_hypotheses,
+                return_hypotheses=cfg.return_hypotheses,
                 channel_selector=cfg.channel_selector,
                 augmentor=augmentor,
             )
```
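
The net effect: `transcribe_speech.py` now forwards the user-controlled `cfg.return_hypotheses` to the model's `transcribe()` call instead of hard-coding `True`, so large runs can request plain strings (e.g. via the Hydra override `return_hypotheses=False`) and avoid holding full `Hypothesis` objects in memory. A minimal sketch of the difference; the checkpoint name and wav paths are placeholder examples, not part of this commit:

```python
# Minimal sketch, assuming a placeholder checkpoint name and audio files.
import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_ctc_small")

# return_hypotheses=False -> plain strings, the memory-saving path
texts = asr_model.transcribe(
    paths2audio_files=["audio1.wav", "audio2.wav"], batch_size=2, return_hypotheses=False,
)
print(texts[0])  # e.g. 'hello world'

# return_hypotheses=True -> rnnt_utils.Hypothesis objects that also carry
# timestamps, language tags, and alignments for write_transcription()
hyps = asr_model.transcribe(
    paths2audio_files=["audio1.wav"], batch_size=1, return_hypotheses=True,
)
print(hyps[0].text)
```
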
73 changes: 41 additions & 32 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
```diff
@@ -272,7 +272,7 @@ def normalize_timestamp_output(timestamps: dict):
 
 
 def write_transcription(
-    transcriptions: Union[List[rnnt_utils.Hypothesis], List[List[rnnt_utils.Hypothesis]]],
+    transcriptions: Union[List[rnnt_utils.Hypothesis], List[List[rnnt_utils.Hypothesis]], List[str]],
     cfg: DictConfig,
     model_name: str,
     filepaths: List[str] = None,
@@ -290,7 +290,11 @@ def write_transcription(
     else:
         pred_text_attr_name = 'pred_text'
 
-    if isinstance(transcriptions[0], rnnt_utils.Hypothesis):  # List[rnnt_utils.Hypothesis]
+    return_hypotheses = True
+    if isinstance(transcriptions[0], str):  # List[str]:
+        best_hyps = transcriptions
+        return_hypotheses = False
+    elif isinstance(transcriptions[0], rnnt_utils.Hypothesis):  # List[rnnt_utils.Hypothesis]
         best_hyps = transcriptions
         assert cfg.decoding.beam.return_best_hypothesis, "Works only with return_best_hypothesis=true"
     elif isinstance(transcriptions[0], list) and isinstance(
@@ -309,31 +313,14 @@ def write_transcription(
 
     with open(cfg.output_filename, 'w', encoding='utf-8', newline='\n') as f:
         if cfg.audio_dir is not None:
-            for idx, transcription in enumerate(best_hyps):  # type: rnnt_utils.Hypothesis
-                item = {'audio_filepath': filepaths[idx], pred_text_attr_name: transcription.text}
-
-                if compute_timestamps:
-                    timestamps = transcription.timestep
-                    if timestamps is not None and isinstance(timestamps, dict):
-                        timestamps.pop('timestep', None)  # Pytorch tensor calculating index of each token, not needed.
-                        for key in timestamps.keys():
-                            values = normalize_timestamp_output(timestamps[key])
-                            item[f'timestamps_{key}'] = values
-
-                if compute_langs:
-                    item['pred_lang'] = transcription.langs
-                    item['pred_lang_chars'] = transcription.langs_chars
-                if not cfg.decoding.beam.return_best_hypothesis:
-                    item['beams'] = beams[idx]
-                f.write(json.dumps(item) + "\n")
-        else:
-            with open(cfg.dataset_manifest, 'r', encoding='utf-8') as fr:
-                for idx, line in enumerate(fr):
-                    item = json.loads(line)
-                    item[pred_text_attr_name] = best_hyps[idx].text
+            for idx, transcription in enumerate(best_hyps):  # type: rnnt_utils.Hypothesis or str
+                if not return_hypotheses:  # transcription is str
+                    item = {'audio_filepath': filepaths[idx], pred_text_attr_name: transcription}
+                else:  # transcription is Hypothesis
+                    item = {'audio_filepath': filepaths[idx], pred_text_attr_name: transcription.text}
 
                 if compute_timestamps:
-                    timestamps = best_hyps[idx].timestep
+                    timestamps = transcription.timestep
                     if timestamps is not None and isinstance(timestamps, dict):
                         timestamps.pop(
                             'timestep', None
@@ -343,11 +330,36 @@ def write_transcription(
                             item[f'timestamps_{key}'] = values
 
                 if compute_langs:
-                    item['pred_lang'] = best_hyps[idx].langs
-                    item['pred_lang_chars'] = best_hyps[idx].langs_chars
-
+                    item['pred_lang'] = transcription.langs
+                    item['pred_lang_chars'] = transcription.langs_chars
                 if not cfg.decoding.beam.return_best_hypothesis:
                     item['beams'] = beams[idx]
                 f.write(json.dumps(item) + "\n")
+        else:
+            with open(cfg.dataset_manifest, 'r', encoding='utf-8') as fr:
+                for idx, line in enumerate(fr):
+                    item = json.loads(line)
+                    if not return_hypotheses:  # transcription is str
+                        item[pred_text_attr_name] = best_hyps[idx]
+                    else:  # transcription is Hypothesis
+                        item[pred_text_attr_name] = best_hyps[idx].text
+
+                    if compute_timestamps:
+                        timestamps = best_hyps[idx].timestep
+                        if timestamps is not None and isinstance(timestamps, dict):
+                            timestamps.pop(
+                                'timestep', None
+                            )  # Pytorch tensor calculating index of each token, not needed.
+                            for key in timestamps.keys():
+                                values = normalize_timestamp_output(timestamps[key])
+                                item[f'timestamps_{key}'] = values
+
+                    if compute_langs:
+                        item['pred_lang'] = best_hyps[idx].langs
+                        item['pred_lang_chars'] = best_hyps[idx].langs_chars
+
+                    if not cfg.decoding.beam.return_best_hypothesis:
+                        item['beams'] = beams[idx]
+                    f.write(json.dumps(item) + "\n")
 
     return cfg.output_filename, pred_text_attr_name
```
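
Both writer branches now key off `return_hypotheses`, which `write_transcription()` infers from the element type of `transcriptions` rather than assuming hypotheses. A condensed, self-contained sketch of that dispatch (simplified stand-in types, not the actual NeMo signature):

```python
# Condensed sketch of the type dispatch; Hypothesis is a simplified
# stand-in for rnnt_utils.Hypothesis.
from dataclasses import dataclass
from typing import List, Union

@dataclass
class Hypothesis:
    text: str

def pick_best_hyps(transcriptions: Union[List[str], List[Hypothesis], List[List[Hypothesis]]]):
    return_hypotheses = True
    if isinstance(transcriptions[0], str):           # transcribe() returned raw text
        best_hyps, return_hypotheses = transcriptions, False
    elif isinstance(transcriptions[0], Hypothesis):  # one best hypothesis per utterance
        best_hyps = transcriptions
    else:                                            # N-best decoding: keep hypothesis 0
        best_hyps = [beam[0] for beam in transcriptions]
    return best_hyps, return_hypotheses

print(pick_best_hyps(["hello world"]))  # (['hello world'], False)
print(pick_best_hyps([[Hypothesis("a"), Hypothesis("b")]])[0][0].text)  # 'a'
```

The remaining hunk in this file touches `transcribe_partial_audio()`:
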
```diff
@@ -434,10 +446,7 @@ def transcribe_partial_audio(
                     lg = logits[idx][: logits_len[idx]]
                     hypotheses.append(lg)
             else:
-                current_hypotheses, all_hyp = decode_function(logits, logits_len, return_hypotheses=return_hypotheses,)
-
-                if isinstance(current_hypotheses, tuple) and len(current_hypotheses) == 2:
-                    current_hypotheses = current_hypotheses[0]
+                current_hypotheses, _ = decode_function(logits, logits_len, return_hypotheses=return_hypotheses,)
 
                 if return_hypotheses:
                     # dump log probs per file
```
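
The decode functions used here already return a `(best_hypotheses, all_hypotheses)` pair, so after unpacking, the old `isinstance(current_hypotheses, tuple)` re-check on the first element was effectively dead code; the refactor unpacks once and discards the N-best list as `_`. A toy illustration with a hypothetical stand-in for the decode callable:

```python
# Toy illustration; decode_function is a hypothetical stand-in for the
# CTC/RNNT decoding callables, which return (best_hypotheses, all_hypotheses).
from typing import List, Optional, Tuple

def decode_function(
    logits, logits_len, return_hypotheses: bool = False
) -> Tuple[List[str], Optional[list]]:
    best = [f"utterance-{i}" for i in range(len(logits))]
    return best, None  # N-best output omitted in this sketch

logits, logits_len = [[0.1], [0.2]], [1, 1]

# After the fix: unpack the pair once and drop the N-best list explicitly.
current_hypotheses, _ = decode_function(logits, logits_len, return_hypotheses=True)
assert not isinstance(current_hypotheses, tuple)  # the old re-check never fired here
print(current_hypotheses)  # ['utterance-0', 'utterance-1']
```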
