Prompt formatter API and canary transcribe tensor input support (#9206)
* Apply CanaryPromptFormatter in dataset/inference

Signed-off-by: Piotr Żelasko <[email protected]>

* Working inference with CanaryPromptFormatter

Signed-off-by: Piotr Żelasko <[email protected]>

* Minimum working example of Canary.transcribe() with tensors

Signed-off-by: Piotr Żelasko <[email protected]>

* training fix

Signed-off-by: Piotr Żelasko <[email protected]>

* Update to the new 'chat' based prompt formatting API

Signed-off-by: Piotr Żelasko <[email protected]>

* Prompt formatters for popular models and partial unit test coverage

Signed-off-by: Piotr Żelasko <[email protected]>

* Updated documentation

Signed-off-by: Piotr Żelasko <[email protected]>

* Improved test coverage + proper preamble support

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix usage of PromptFormatter for MT-AED class + fix tokenization/formatting issues

Signed-off-by: Piotr Żelasko <[email protected]>

* Move some canary hacks to canary prompt formatter, improve validation, add tests for aggtok

Signed-off-by: Piotr Żelasko <[email protected]>

* aed_model.transcribe(**slots) support, rename all slots to lowercase and drop pipes everywhere except template definition.

Signed-off-by: Piotr Żelasko <[email protected]>

* truly generic version

Signed-off-by: Piotr Żelasko <[email protected]>

* making transcribe_speech.py work with prompt slots + syntactic sugar

Signed-off-by: Piotr Żelasko <[email protected]>

* update streaming_utils.py

Signed-off-by: Piotr Żelasko <[email protected]>

* fix

Signed-off-by: Piotr Żelasko <[email protected]>

* code review: partial

Signed-off-by: Piotr Żelasko <[email protected]>

* Accept multi-turn, single-turn, and legacy prompt formats in transcribe() and transcribe_speech.py

Signed-off-by: Piotr Żelasko <[email protected]>

* Address code reviews

Signed-off-by: Piotr Żelasko <[email protected]>

* Add support for SPE special tokens bos/eos in prompt templates and ensure Llama2 format gives identical results with the reference implementation

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix tests and add llama2 prompt formatter tests

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix tests

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
Signed-off-by: Jan Lasek <[email protected]>
pzelasko authored and janekl committed Jun 12, 2024
1 parent 234b53f commit cbe51c4
Showing 26 changed files with 1,382 additions and 211 deletions.
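
For orientation, below is a hedged sketch of the user-facing surface this commit describes: tensor inputs to transcribe() plus prompt slots passed as keyword arguments (per the "aed_model.transcribe(**slots) support" message). The checkpoint name and the exact keyword handling are assumptions, not something verified against this diff.

import torch

from nemo.collections.asr.models import EncDecMultiTaskModel

# Assumed pretrained model identifier; any Canary-style AED checkpoint would do.
model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")

# Tensor input instead of file paths (the "canary transcribe tensor input support" part).
audio = torch.randn(16000)  # one second of synthetic 16 kHz audio

# Implicit single-turn prompt: the slots are assumed to be forwarded to the Canary prompt formatter.
hypotheses = model.transcribe(
    audio=audio,
    source_lang="en",
    target_lang="en",
    task="asr",
    pnc="yes",
)
print(hypotheses)

The same slots can also be supplied through the prompt config of transcribe_speech.py, as the first file below shows.
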
13 changes: 12 additions & 1 deletion examples/asr/transcribe_speech.py
@@ -16,7 +16,7 @@
import glob
import json
import os
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass, field, is_dataclass
from tempfile import NamedTemporaryFile
from typing import List, Optional, Union

@@ -25,6 +25,7 @@
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel
from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig
@@ -169,6 +170,14 @@ class TranscriptionConfig:

# Decoding strategy for AED models
multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig()
# Prompt slots for prompted models, e.g. Canary-1B. Examples of acceptable prompt inputs:
# Implicit single-turn assuming default role='user' (works with Canary-1B)
# +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes
# Explicit single-turn prompt:
# +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes
# Explicit multi-turn prompt:
# +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]'
prompt: dict = field(default_factory=dict)

# decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
decoder_type: Optional[str] = None
@@ -411,6 +420,8 @@ def autocast(dtype=None):
override_cfg.augmentor = augmentor
override_cfg.text_field = cfg.gt_text_attr_name
override_cfg.lang_field = cfg.gt_lang_attr_name
if hasattr(override_cfg, "prompt"):
override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt))
transcriptions = asr_model.transcribe(
audio=filepaths,
override_config=override_cfg,
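
To make the three accepted prompt shapes above concrete, the sketch below builds them as plain Python dicts and hands them to parse_multitask_prompt; the function's output format is not shown in this diff and is assumed to be the normalized multi-turn form.

from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt

# Implicit single-turn (role defaults to "user").
implicit = {"source_lang": "en", "target_lang": "es", "task": "asr", "pnc": "yes"}

# Explicit single-turn.
explicit = {
    "role": "user",
    "slots": {"source_lang": "en", "target_lang": "es", "task": "s2t_translation", "pnc": "yes"},
}

# Explicit multi-turn.
multi_turn = {
    "turns": [{"role": "user", "slots": {"source_lang": "en", "target_lang": "es", "task": "asr", "pnc": "yes"}}]
}

for prompt in (implicit, explicit, multi_turn):
    print(parse_multitask_prompt(prompt))  # assumed to normalize each shape the same way

Whichever shape is used on the command line, transcribe_speech.py converts it to a plain dict (OmegaConf.to_container) before handing it to the model's override config.
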
45 changes: 24 additions & 21 deletions nemo/collections/asr/data/audio_to_text.py
@@ -75,7 +75,9 @@ def _speech_collate_fn(batch, pad_id):
has_audio = audio_lengths[0] is not None
if has_audio:
max_audio_len = max(audio_lengths).item()
max_tokens_len = max(tokens_lengths).item()
has_tokens = tokens_lengths[0] is not None
if has_tokens:
max_tokens_len = max(tokens_lengths).item()

audio_signal, tokens = [], []
for b in batch:
@@ -89,19 +91,24 @@
pad = (0, max_audio_len - sig_len)
sig = torch.nn.functional.pad(sig, pad)
audio_signal.append(sig)
tokens_i_len = tokens_i_len.item()
if tokens_i_len < max_tokens_len:
pad = (0, max_tokens_len - tokens_i_len)
tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
tokens.append(tokens_i)
if has_tokens:
tokens_i_len = tokens_i_len.item()
if tokens_i_len < max_tokens_len:
pad = (0, max_tokens_len - tokens_i_len)
tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
tokens.append(tokens_i)

if has_audio:
audio_signal = torch.stack(audio_signal)
audio_lengths = torch.stack(audio_lengths)
else:
audio_signal, audio_lengths = None, None
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
if has_tokens:
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
else:
tokens = None
tokens_lengths = None
if sample_ids is None:
return audio_signal, audio_lengths, tokens, tokens_lengths
else:
@@ -256,8 +263,7 @@ def cache_datastore_manifests(
if num_datastore_manifests > 0:
# Local utility function
def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
"""Cache manifests and audio data from object store.
"""
"""Cache manifests and audio data from object store."""
# Determine the number of workers to use
if num_workers is None:
num_workers = os.cpu_count() - 1
@@ -421,8 +427,7 @@ class _AudioTextDataset(Dataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -546,8 +551,7 @@ class AudioToCharDataset(_AudioTextDataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -640,8 +644,7 @@ class AudioToBPEDataset(_AudioTextDataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -910,8 +913,7 @@ def __next__(self):
return TarredAudioFilter(self.manifest_processor.collection)

def _loop_offsets(self, iterator):
"""This function is used to iterate through utterances with different offsets for each file.
"""
"""This function is used to iterate through utterances with different offsets for each file."""

class TarredAudioLoopOffsets:
def __init__(self, collection):
@@ -944,8 +946,7 @@ def _collate_fn(self, batch):
return _speech_collate_fn(batch, self.pad_id)

def _build_sample(self, tup):
"""Builds the training sample by combining the data from the WebDataset with the manifest info.
"""
"""Builds the training sample by combining the data from the WebDataset with the manifest info."""
audio_bytes, audio_filename, offset_id = tup

# Grab manifest entry from self.manifest_preprocessor.collection
@@ -1316,7 +1317,9 @@ class BucketingDataset(IterableDataset):
"""

def __init__(
self, dataset: IterableDataset, bucketing_batch_size: int,
self,
dataset: IterableDataset,
bucketing_batch_size: int,
):
self.wrapped_dataset = dataset
self.bucketing_batch_size = bucketing_batch_size
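
The substantive change to _speech_collate_fn above is that token tensors may now be absent (inference-only batches). Below is a minimal, self-contained sketch of that optional-token padding pattern; it is not the NeMo function itself, only an illustration of the has_tokens branching.

import torch

def collate(batch, pad_id=0):
    """batch: list of (signal, signal_len, tokens_or_None, tokens_len_or_None) tuples."""
    signals, signal_lens, tokens, token_lens = zip(*batch)

    # Pad audio to the longest signal in the batch.
    max_sig = max(l.item() for l in signal_lens)
    audio = torch.stack([torch.nn.functional.pad(s, (0, max_sig - s.numel())) for s in signals])
    audio_lens = torch.stack(signal_lens)

    has_tokens = token_lens[0] is not None
    if has_tokens:
        max_tok = max(l.item() for l in token_lens)
        padded = [torch.nn.functional.pad(t, (0, max_tok - t.numel()), value=pad_id) for t in tokens]
        tokens, token_lens = torch.stack(padded), torch.stack(token_lens)
    else:
        tokens, token_lens = None, None  # inference-only batches carry no target tokens

    return audio, audio_lens, tokens, token_lens

# Example: an inference-style batch without target tokens.
batch = [
    (torch.randn(8), torch.tensor(8), None, None),
    (torch.randn(5), torch.tensor(5), None, None),
]
audio, audio_lens, toks, tok_lens = collate(batch)
assert toks is None and audio.shape == (2, 8)

The real NeMo function additionally handles sample IDs and audio-less batches; the sketch keeps only the branch this hunk adds.
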
158 changes: 49 additions & 109 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -13,15 +13,16 @@
# limitations under the License.
from typing import Callable, Sequence

import omegaconf
import torch.utils.data
from lhotse import CutSet
from lhotse.cut import MixedCut, MonoCut
from lhotse.dataset import AudioSamples
from lhotse.dataset.collation import collate_vectors

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.tokenizers import CanaryTokenizer, TokenizerSpec
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER


class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
@@ -57,21 +58,21 @@ def __init__(
def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
audio, audio_lens, cuts = self.load_audio(cuts)

tokens, prompt_tokens = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)
prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)

tokens = [torch.as_tensor(t) for t in tokens]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
tokens = collate_vectors(tokens, padding_value=self.padding_value)
prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers]
prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long)
prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value)

if self.inference:
prompt_tokens = [torch.as_tensor(t) for t in prompt_tokens]
prompt_token_lens = torch.tensor([t.size(0) for t in prompt_tokens], dtype=torch.long)
prompt_tokens = collate_vectors(prompt_tokens, padding_value=self.padding_value)
prompts = [torch.as_tensor(t) for t in prompts]
prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)
prompts = collate_vectors(prompts, padding_value=self.padding_value)
else:
prompt_tokens = None
prompt_token_lens = None
prompts = None
prompts_lens = None

return audio, audio_lens, tokens, token_lens, prompt_tokens, prompt_token_lens
return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens


# Mapping from a string name to a known prompt formatter function.
@@ -105,7 +106,9 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool]


@registered_prompt_format_fn
def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -> Sequence[Sequence[int]]:
def canary(
cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Prepend and append control tokens to the token sequence as per Canary format.
@@ -132,116 +135,53 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -
assert isinstance(
tokenizer._tokenizer, CanaryTokenizer
), "To use 'canary' prompt format, you must use the CanaryTokenizer."
tokenizer = tokenizer._tokenizer
formatter = CanaryPromptFormatter(tokenizer._tokenizer)

tokens, prompts = [], []
prompts_with_answers, prompts = [], []
for cut in cuts:
if isinstance(cut, MixedCut):
cut = cut._first_non_padding_cut
assert isinstance(cut, MonoCut), "Expected MonoCut."
if not isinstance(cut, MonoCut):
raise TypeError(
f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})"
)

# first, validate the utterance
missing_keys = [k for k in ("source_lang", "target_lang", "taskname", "pnc") if k not in cut.custom]
expected_slots = set(formatter.get_slots("user"))
missing_keys = expected_slots - set(cut.custom)
if "task" in missing_keys and "taskname" in cut.custom:
# Compatibility with "old" Canary manifest format.
# For compatibility with inference options, this slot is now called "task".
cut.custom["task"] = cut.custom["taskname"]
missing_keys.remove("task")
if missing_keys:
raise RuntimeError(
f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}"
f"Please ensure that every utterance in the input manifests contains these keys."
)

# Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together.
texts = [sup.text for sup in cut.supervisions]
langs = [sup.language for sup in cut.supervisions]
taskname = cut.custom['taskname']
pnc = cut.custom['pnc']
source_lang = cut.custom['source_lang']
target_lang = cut.custom['target_lang']

tokens.append(canary_prompt(tokenizer, texts, langs, source_lang, target_lang, taskname, pnc))
if inference:
prompts.append(canary_prompt(tokenizer, None, None, source_lang, target_lang, taskname, pnc))
return tokens, prompts


def canary_prompt(
tokenizer: CanaryTokenizer,
text: str | list[str] | None,
language: str | list[str] | None,
source_language: str,
target_language: str,
taskname: str,
pnc: str,
) -> list[int]:
if isinstance(text, str):
text = [text]
if isinstance(language, str):
language = [language]

if text is not None:
try:
tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[])
except omegaconf.errors.KeyValidationError as e:
raise ProbablyIncorrectLanguageKeyError(
"We couldn't select the right tokenizer, which could be due to issues with reading "
"the language from the manifest. "
"If you're training, try setting lang_field='' to a different value (probably 'target_lang' or 'lang'). "
"If you're using model.transcribe() directly, please use override_config kwarg to set this. "
"If you're using transcribe_speech.py, use option gt_lang_attr_name='...' "
) from e
else:
tokens = None # create prompt for inference

# bos
prompted_tokens = [tokenizer.bos_id]

if tokens is not None and len(tokens) == 0:
# no speech token
prompted_tokens.append(tokenizer.nospeech_id)
else:
# first, validate the utterance
if source_language is None or target_language is None or taskname is None or pnc is None:
raise RuntimeError(
f"Missing keys provided to prompt: "
f"source_langauge={source_language},\n"
f"target_language={target_language},\n"
f"taskname={taskname},\n"
f"pnc={pnc}\n"
f"Please ensure that every utterance in the input manifests contains these keys."
)

# src_lang_id/no_speech
src_lang_id = tokenizer.spl_token_to_id(source_language)
prompted_tokens.append(src_lang_id)

# task
task = taskname
if task == 'asr' or task == "transcribe":
prompted_tokens.append(tokenizer.spl_token_to_id("transcribe"))
elif task == 's2t_translation' or task == 'ast' or task == "translate":
prompted_tokens.append(tokenizer.spl_token_to_id("translate"))
else:
raise ValueError(f"Unknown task: {task}")

# tgt_lang_id
tgt_lang_id = tokenizer.spl_token_to_id(target_language)
prompted_tokens.append(tgt_lang_id)

# PnC
pnc = f"{pnc}".lower().strip() # to account for bool or str
if pnc in {'yes', 'true'}:
prompted_tokens.append(tokenizer.spl_token_to_id("pnc"))
elif pnc in {'no', 'false'}:
prompted_tokens.append(tokenizer.spl_token_to_id("nopnc"))
else:
raise ValueError(f"Unknown value for key 'pnc': {pnc}")

# text (only in training)
if tokens is not None:
prompted_tokens.extend(tokens)
encoded = formatter.encode_dialog(
turns=[
dict(
role="user",
slots={
**{slot: cut.custom[slot] for slot in expected_slots},
formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER,
},
),
dict(
role="assistant",
slots={
"text": ' '.join(s.text for s in cut.supervisions),
formatter.PROMPT_LANGUAGE_SLOT: cut.custom["target_lang"],
},
),
]
)
prompts_with_answers.append(encoded["input_ids"])
prompts.append(encoded["context_ids"])

# eos (only in training)
if tokens is not None:
prompted_tokens.append(tokenizer.eos_id)
return prompted_tokens
return prompts_with_answers, prompts


class ProbablyIncorrectLanguageKeyError(RuntimeError):
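
All prompt construction now funnels through CanaryPromptFormatter.encode_dialog(). The helper below is a hedged sketch of that call pattern: the slot names and the returned "input_ids"/"context_ids" keys follow the diff, but the rest of the formatter's API is assumed.

from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.tokenizers import CanaryTokenizer
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER


def encode_utterance(tokenizer: CanaryTokenizer, text: str, source_lang: str, target_lang: str):
    """Sketch only: returns (prompt_with_answer_ids, prompt_only_ids) for one utterance."""
    formatter = CanaryPromptFormatter(tokenizer)
    encoded = formatter.encode_dialog(
        turns=[
            {
                "role": "user",
                "slots": {
                    "source_lang": source_lang,
                    "target_lang": target_lang,
                    "task": "asr",
                    "pnc": "yes",
                    # Prompt control tokens are rendered with the special-token tokenizer.
                    formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER,
                },
            },
            {
                "role": "assistant",
                "slots": {"text": text, formatter.PROMPT_LANGUAGE_SLOT: target_lang},
            },
        ]
    )
    # "input_ids" is prompt + answer (training target); "context_ids" is the prompt alone (inference context).
    return encoded["input_ids"], encoded["context_ids"]

At inference the dataset keeps only the prompt side, which is why __getitem__ above returns prompts and prompts_lens only when inference=True.
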
