Prompt formatter API and canary transcribe tensor input support #9206

Merged · 22 commits · Jun 1, 2024
Changes from 10 commits

Commits (22)
bfbacdc
Apply CanaryPromptFormatter in dataset/inference
pzelasko May 15, 2024
967776b
Working inference with CanaryPromptFormatter
pzelasko May 15, 2024
04fdba9
Minimum working example of Canary.transcribe() with tensors
pzelasko May 15, 2024
1f902ff
training fix
pzelasko May 15, 2024
e86362e
Update to the new 'chat' based prompt formatting API
pzelasko May 21, 2024
41f96f1
Prompt formatters for popular models and partial unit test coverage
pzelasko May 21, 2024
06ff96d
Updated documentation
pzelasko May 21, 2024
71b9191
Improved test coverage + proper preamble support
pzelasko May 22, 2024
5555a4c
Fix usage of PromptFormatter for MT-AED class + fix tokenization/form…
pzelasko May 22, 2024
2350356
Move some canary hacks to canary prompt formatter, improve validation…
pzelasko May 22, 2024
30713b8
aed_model.transcribe(**slots) support, rename all slots to lowercase …
pzelasko May 23, 2024
9334a88
truly generic version
pzelasko May 23, 2024
2f7cd7a
making transcribe_speech.py work prompt slots + syntactic sugar
pzelasko May 23, 2024
3a533ae
update streaming_utils.py
pzelasko May 23, 2024
d6f75f0
Merge branch 'main' into prompt-formatter-and-canary-tensor-dataset
pzelasko May 23, 2024
61f92d8
fix
pzelasko May 23, 2024
9fe28cb
code review: partial
pzelasko May 24, 2024
3f60244
Accept multi-turn, single-turn, and legacy prompt format in transcrib…
pzelasko May 29, 2024
9e13c2e
Address code reviews
pzelasko May 31, 2024
3f9453b
Add support for SPE special tokens bos/eos in prompt templates and en…
pzelasko May 31, 2024
55ac422
Fix tests and add llama2 prompt formatter tests
pzelasko May 31, 2024
43ec9ad
Fix tests
pzelasko May 31, 2024
45 changes: 24 additions & 21 deletions nemo/collections/asr/data/audio_to_text.py
@@ -75,7 +75,9 @@ def _speech_collate_fn(batch, pad_id):
has_audio = audio_lengths[0] is not None
if has_audio:
max_audio_len = max(audio_lengths).item()
max_tokens_len = max(tokens_lengths).item()
has_tokens = tokens_lengths[0] is not None
if has_tokens:
max_tokens_len = max(tokens_lengths).item()

audio_signal, tokens = [], []
for b in batch:
@@ -89,19 +91,24 @@ def _speech_collate_fn(batch, pad_id):
pad = (0, max_audio_len - sig_len)
sig = torch.nn.functional.pad(sig, pad)
audio_signal.append(sig)
tokens_i_len = tokens_i_len.item()
if tokens_i_len < max_tokens_len:
pad = (0, max_tokens_len - tokens_i_len)
tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
tokens.append(tokens_i)
if has_tokens:
tokens_i_len = tokens_i_len.item()
if tokens_i_len < max_tokens_len:
pad = (0, max_tokens_len - tokens_i_len)
tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
tokens.append(tokens_i)

if has_audio:
audio_signal = torch.stack(audio_signal)
audio_lengths = torch.stack(audio_lengths)
else:
audio_signal, audio_lengths = None, None
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
if has_tokens:
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
else:
tokens = None
tokens_lengths = None
if sample_ids is None:
return audio_signal, audio_lengths, tokens, tokens_lengths
else:
@@ -256,8 +263,7 @@ def cache_datastore_manifests(
if num_datastore_manifests > 0:
# Local utility function
def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
"""Cache manifests and audio data from object store.
"""
"""Cache manifests and audio data from object store."""
# Determine the number of workers to use
if num_workers is None:
num_workers = os.cpu_count() - 1
@@ -421,8 +427,7 @@ class _AudioTextDataset(Dataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -546,8 +551,7 @@ class AudioToCharDataset(_AudioTextDataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -640,8 +644,7 @@ class AudioToBPEDataset(_AudioTextDataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -910,8 +913,7 @@ def __next__(self):
return TarredAudioFilter(self.manifest_processor.collection)

def _loop_offsets(self, iterator):
"""This function is used to iterate through utterances with different offsets for each file.
"""
"""This function is used to iterate through utterances with different offsets for each file."""

class TarredAudioLoopOffsets:
def __init__(self, collection):
@@ -944,8 +946,7 @@ def _collate_fn(self, batch):
return _speech_collate_fn(batch, self.pad_id)

def _build_sample(self, tup):
"""Builds the training sample by combining the data from the WebDataset with the manifest info.
"""
"""Builds the training sample by combining the data from the WebDataset with the manifest info."""
audio_bytes, audio_filename, offset_id = tup

# Grab manifest entry from self.manifest_preprocessor.collection
@@ -1316,7 +1317,9 @@ class BucketingDataset(IterableDataset):
"""

def __init__(
self, dataset: IterableDataset, bucketing_batch_size: int,
self,
dataset: IterableDataset,
bucketing_batch_size: int,
):
self.wrapped_dataset = dataset
self.bucketing_batch_size = bucketing_batch_size
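The change above lets `_speech_collate_fn` tolerate batches that carry no text tokens, which is what the inference path now produces. A minimal sketch, assuming 4-tuple samples of (audio, audio_len, tokens, tokens_len) as returned by the updated `TranscriptionMixin.get_item` further below; the waveform lengths are made-up values for the example:

```python
import torch

from nemo.collections.asr.data.audio_to_text import _speech_collate_fn

# Inference-style samples: raw audio plus lengths, with text tokens set to None.
batch = [
    (torch.randn(16000), torch.tensor(16000), None, None),
    (torch.randn(12000), torch.tensor(12000), None, None),
]

audio, audio_lens, tokens, token_lens = _speech_collate_fn(batch, pad_id=0)
# audio is zero-padded to the longest waveform and stacked -> shape (2, 16000);
# tokens and token_lens come back as None because no sample carried text.
```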
121 changes: 27 additions & 94 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -21,7 +21,9 @@
from lhotse.dataset.collation import collate_vectors

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.tokenizers import CanaryTokenizer, TokenizerSpec
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER


class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
@@ -132,7 +134,7 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -
assert isinstance(
tokenizer._tokenizer, CanaryTokenizer
), "To use 'canary' prompt format, you must use the CanaryTokenizer."
tokenizer = tokenizer._tokenizer
formatter = CanaryPromptFormatter(tokenizer._tokenizer)

tokens, prompts = [], []
for cut in cuts:
@@ -148,100 +150,31 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -
f"Please ensure that every utterance in the input manifests contains these keys."
)

# Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together.
texts = [sup.text for sup in cut.supervisions]
langs = [sup.language for sup in cut.supervisions]
taskname = cut.custom['taskname']
pnc = cut.custom['pnc']
source_lang = cut.custom['source_lang']
target_lang = cut.custom['target_lang']

tokens.append(canary_prompt(tokenizer, texts, langs, source_lang, target_lang, taskname, pnc))
if inference:
prompts.append(canary_prompt(tokenizer, None, None, source_lang, target_lang, taskname, pnc))
return tokens, prompts


def canary_prompt(
tokenizer: CanaryTokenizer,
text: str | list[str] | None,
language: str | list[str] | None,
source_language: str,
target_language: str,
taskname: str,
pnc: str,
) -> list[int]:
if isinstance(text, str):
text = [text]
if isinstance(language, str):
language = [language]

if text is not None:
try:
tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[])
except omegaconf.errors.KeyValidationError as e:
raise ProbablyIncorrectLanguageKeyError(
"We couldn't select the right tokenizer, which could be due to issues with reading "
"the language from the manifest. "
"If you're training, try setting lang_field='' to a different value (probably 'target_lang' or 'lang'). "
"If you're using model.transcribe() directly, please use override_config kwarg to set this. "
"If you're using transcribe_speech.py, use option gt_lang_attr_name='...' "
) from e
else:
tokens = None # create prompt for inference

# bos
prompted_tokens = [tokenizer.bos_id]

if tokens is not None and len(tokens) == 0:
# no speech token
prompted_tokens.append(tokenizer.nospeech_id)
else:
# first, validate the utterance
if source_language is None or target_language is None or taskname is None or pnc is None:
raise RuntimeError(
f"Missing keys provided to prompt: "
f"source_langauge={source_language},\n"
f"target_language={target_language},\n"
f"taskname={taskname},\n"
f"pnc={pnc}\n"
f"Please ensure that every utterance in the input manifests contains these keys."
)

# src_lang_id/no_speech
src_lang_id = tokenizer.spl_token_to_id(source_language)
prompted_tokens.append(src_lang_id)

# task
task = taskname
if task == 'asr' or task == "transcribe":
prompted_tokens.append(tokenizer.spl_token_to_id("transcribe"))
elif task == 's2t_translation' or task == 'ast' or task == "translate":
prompted_tokens.append(tokenizer.spl_token_to_id("translate"))
else:
raise ValueError(f"Unknown task: {task}")

# tgt_lang_id
tgt_lang_id = tokenizer.spl_token_to_id(target_language)
prompted_tokens.append(tgt_lang_id)

# PnC
pnc = f"{pnc}".lower().strip() # to account for bool or str
if pnc in {'yes', 'true'}:
prompted_tokens.append(tokenizer.spl_token_to_id("pnc"))
elif pnc in {'no', 'false'}:
prompted_tokens.append(tokenizer.spl_token_to_id("nopnc"))
else:
raise ValueError(f"Unknown value for key 'pnc': {pnc}")

# text (only in training)
if tokens is not None:
prompted_tokens.extend(tokens)
encoded = formatter.encode_dialog(
turns=[
dict(
role="user",
slots={
"|SOURCE_LANG|": cut.custom['source_lang'],
"|TARGET_LANG|": cut.custom['target_lang'],
"|TASK|": cut.custom['taskname'],
"|PNC|": cut.custom['pnc'],
"|PROMPT_LANGUAGE|": CANARY_SPECIAL_TOKENIZER,
},
),
dict(
role="assistant",
slots={
"|TEXT|": ' '.join(s.text for s in cut.supervisions),
"|PROMPT_LANGUAGE|": cut.custom["target_lang"],
},
),
]
)
tokens.append(encoded["input_ids"])
prompts.append(encoded["context_ids"])

# eos (only in training)
if tokens is not None:
prompted_tokens.append(tokenizer.eos_id)
return prompted_tokens
return tokens, prompts


class ProbablyIncorrectLanguageKeyError(RuntimeError):
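For reference, a hedged sketch of the prompt formatter API that the dataset code above now drives: the slot names and the `CANARY_SPECIAL_TOKENIZER` constant are taken from the diff, while the tokenizer object and the concrete slot values are placeholder assumptions.

```python
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER

# `canary_tokenizer` is assumed to be the CanaryTokenizer of a loaded Canary model.
formatter = CanaryPromptFormatter(canary_tokenizer)

encoded = formatter.encode_dialog(
    turns=[
        {
            "role": "user",
            "slots": {
                "|SOURCE_LANG|": "en",
                "|TARGET_LANG|": "de",
                "|TASK|": "ast",
                "|PNC|": "yes",
                "|PROMPT_LANGUAGE|": CANARY_SPECIAL_TOKENIZER,
            },
        },
        {
            "role": "assistant",
            "slots": {
                "|TEXT|": "hallo welt",
                "|PROMPT_LANGUAGE|": "de",
            },
        },
    ]
)

# encoded["context_ids"]: prompt-only token IDs (what inference feeds the decoder);
# encoded["input_ids"]: prompt + target text token IDs (the training target sequence).
```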
51 changes: 44 additions & 7 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -44,6 +44,8 @@
from nemo.collections.common.metrics import GlobalAverageLossMetric
from nemo.collections.common.parts import transformer_weights_init
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
from nemo.collections.common.prompts.formatter import PromptFormatter
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER
from nemo.core.classes.common import typecheck
from nemo.core.neural_types import (
AudioSignal,
@@ -106,9 +108,33 @@
text_field: str = "answer"
lang_field: str = "target_lang"

def get_prompt_turns_with_slot_values(self, prompt_format: str) -> list[dict[str, str]]:
formatter = PromptFormatter.resolve(prompt_format)
# Canary requires special handling to satisfy prompt formatter API.
if prompt_format == "canary":
# First level of indirection: map trcfg field names to prompt formatter slot names.
slot_to_trcfg_key = {
"|TASK|": "task",
"|PNC|": "pnc",
"|SOURCE_LANG|": "source_lang",
"|TARGET_LANG|": "target_lang",
}
# Now ask the prompt formatter about which slots are required
# (we know it up-front for canary, but in a generic case this is how it would have looked like).
slots = formatter.get_slots(role="user")
for slot in slots:
slots[slot] = getattr(self, slot_to_trcfg_key[slot])
# Extra slot for aggregate tokenizer support.
slots["|PROMPT_LANGUAGE|"] = CANARY_SPECIAL_TOKENIZER
return [{"role": "user", "slots": slots}]
else:
raise NotImplementedError(
f"We dont support other prompt formats than 'canary' yet (received {prompt_format=})"
)

_internal: Optional[MultiTaskTranscriptionInternalConfig] = field(
default_factory=lambda: MultiTaskTranscriptionInternalConfig()
)

Check notice (Code scanning / CodeQL): Unnecessary lambda — this 'lambda' is just a simple wrapper around a callable object; use that object directly.

def __post_init__(self):
required_fields = ['task', 'pnc', 'source_lang', 'target_lang', 'text_field', 'lang_field']
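For illustration, the structure that `get_prompt_turns_with_slot_values("canary")` would return for a config requesting English ASR with punctuation; the concrete values are assumptions for the example:

```python
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER

# Shape of the returned turns for prompt_format="canary" (illustrative values).
turns = [
    {
        "role": "user",
        "slots": {
            "|TASK|": "asr",
            "|PNC|": "yes",
            "|SOURCE_LANG|": "en",
            "|TARGET_LANG|": "en",
            "|PROMPT_LANGUAGE|": CANARY_SPECIAL_TOKENIZER,
        },
    }
]
```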
@@ -156,7 +182,7 @@
self.transf_encoder = EncDecMultiTaskModel.from_config_dict(transf_encoder_cfg_dict)

# Initialize weights
std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden ** 0.5
std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden**0.5
self.transf_encoder.apply(lambda module: transformer_weights_init(module, std_init_range))

transf_decoder_cfg_dict = cfg.transf_decoder
@@ -182,7 +208,7 @@
self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight

# Initialize weights
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5
self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range))
self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

@@ -347,7 +373,7 @@
self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight

# Initialize weights of token classifier
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5
self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

# Setup Decoding class
@@ -716,6 +742,8 @@
# Switch model to evaluation mode
self.transf_decoder.freeze()

self.prompt_formatter = PromptFormatter.resolve(self.prompt_format)(self.tokenizer)

if isinstance(audio, list):
logging.debug(f"Found 'audio' to be a list of {len(audio)} items.")
logging.debug(f"Assuming each item in 'audio' is a path to audio file.")
@@ -736,9 +764,6 @@
if hasattr(trcfg, '_internal') and hasattr(trcfg._internal, 'manifest_path'):
trcfg._internal.manifest_filepath = manifest_path

elif isinstance(audio, (np.ndarray, torch.Tensor)):
raise NotImplementedError("Transcribing from numpy or torch tensors is not supported yet.")

def _transcribe_input_manifest_processing(
self, audio_files: List[str], temp_dir: str, trcfg: MultiTaskTranscriptionConfig
) -> Dict[str, Any]:
@@ -790,7 +815,19 @@
log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=batch[0], input_signal_length=batch[1]
)
decoder_input_ids = batch[-2].to(trcfg._internal.device)
if len(batch) == 6:
Collaborator comment: This is interesting; it will need to be updated in the future when the batch potentially carries more task inputs. For now, it's fine.
# Prompt provided by the dataloader.
decoder_input_ids = batch[4]
else:
# The dataloader provided only audio + audio_lens, so we
# are constructing the prompt dynamically using TranscribeConfig.
turns = trcfg.get_prompt_turns_with_slot_values(self.prompt_format)
decoder_input_ids = (
self.prompt_formatter.encode_dialog(turns=turns)["context_ids"]
.unsqueeze(0)
.repeat(batch[0].shape[0], 1)
.to(trcfg._internal.device)
)
output = dict(
log_probs=log_probs,
encoded_lengths=encoded_len,
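Putting it together, a hedged usage sketch of how inference could now be invoked with prompt slots passed through the transcription config (the `aed_model.transcribe(**slots)` support mentioned in the commit log); the checkpoint name and keyword names mirror the `MultiTaskTranscriptionConfig` fields above and are assumptions, not a confirmed public signature:

```python
from nemo.collections.asr.models import EncDecMultiTaskModel

# Assumed pretrained checkpoint name for a Canary-style multi-task model.
model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")

# The slot values land in MultiTaskTranscriptionConfig and are turned into
# decoder_input_ids by the prompt formatter when the dataloader supplies
# only (audio, audio_len) batches.
hypotheses = model.transcribe(
    audio=["utterance.wav"],
    source_lang="en",
    target_lang="en",
    task="asr",
    pnc="yes",
)
```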
8 changes: 3 additions & 5 deletions nemo/collections/asr/parts/mixins/transcription.py
@@ -148,11 +148,9 @@ def get_item(self, index):
# Calculate seq length
seq_len = torch.tensor(samples.shape[0], dtype=torch.long)

# Dummy text tokens
text_tokens = torch.tensor([0], dtype=torch.long)
text_tokens_len = torch.tensor(1, dtype=torch.long)

return (samples, seq_len, text_tokens, text_tokens_len)
# Typically NeMo ASR models expect the mini-batch to be a 4-tuple of (audio, audio_len, text, text_len).
# For inference, we set text and text_len to None to not disrupt the shape of the tuple.
return samples, seq_len, None, None


class TranscriptionMixin(ABC):