diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 1763c2035805..b63e9db5fef1 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -16,7 +16,7 @@ import glob import json import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass, field, is_dataclass from tempfile import NamedTemporaryFile from typing import List, Optional, Union @@ -25,6 +25,7 @@ from omegaconf import OmegaConf, open_dict from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel +from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig @@ -169,6 +170,14 @@ class TranscriptionConfig: # Decoding strategy for AED models multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig() + # Prompt slots for prompted models, e.g. Canary-1B. Examples of acceptable prompt inputs: + # Implicit single-turn assuming default role='user' (works with Canary-1B) + # +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes + # Explicit single-turn prompt: + # +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes + # Explicit multi-turn prompt: + # +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]' + prompt: dict = field(default_factory=dict) # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models decoder_type: Optional[str] = None @@ -411,6 +420,8 @@ def autocast(dtype=None): override_cfg.augmentor = augmentor override_cfg.text_field = cfg.gt_text_attr_name override_cfg.lang_field = cfg.gt_lang_attr_name + if hasattr(override_cfg, "prompt"): + override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt)) transcriptions = asr_model.transcribe( audio=filepaths, override_config=override_cfg, diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index 00c15109b64f..e0bb63ad18cd 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -75,7 +75,9 @@ def _speech_collate_fn(batch, pad_id): has_audio = audio_lengths[0] is not None if has_audio: max_audio_len = max(audio_lengths).item() - max_tokens_len = max(tokens_lengths).item() + has_tokens = tokens_lengths[0] is not None + if has_tokens: + max_tokens_len = max(tokens_lengths).item() audio_signal, tokens = [], [] for b in batch: @@ -89,19 +91,24 @@ def _speech_collate_fn(batch, pad_id): pad = (0, max_audio_len - sig_len) sig = torch.nn.functional.pad(sig, pad) audio_signal.append(sig) - tokens_i_len = tokens_i_len.item() - if tokens_i_len < max_tokens_len: - pad = (0, max_tokens_len - tokens_i_len) - tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id) - tokens.append(tokens_i) + if has_tokens: + tokens_i_len = tokens_i_len.item() + if tokens_i_len < max_tokens_len: + pad = (0, max_tokens_len - tokens_i_len) + tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id) + tokens.append(tokens_i) if has_audio: audio_signal = torch.stack(audio_signal) audio_lengths = torch.stack(audio_lengths) else: audio_signal, audio_lengths = 
None, None - tokens = torch.stack(tokens) - tokens_lengths = torch.stack(tokens_lengths) + if has_tokens: + tokens = torch.stack(tokens) + tokens_lengths = torch.stack(tokens_lengths) + else: + tokens = None + tokens_lengths = None if sample_ids is None: return audio_signal, audio_lengths, tokens, tokens_lengths else: @@ -256,8 +263,7 @@ def cache_datastore_manifests( if num_datastore_manifests > 0: # Local utility function def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers): - """Cache manifests and audio data from object store. - """ + """Cache manifests and audio data from object store.""" # Determine the number of workers to use if num_workers is None: num_workers = os.cpu_count() - 1 @@ -421,8 +427,7 @@ class _AudioTextDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -546,8 +551,7 @@ class AudioToCharDataset(_AudioTextDataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -640,8 +644,7 @@ class AudioToBPEDataset(_AudioTextDataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -910,8 +913,7 @@ def __next__(self): return TarredAudioFilter(self.manifest_processor.collection) def _loop_offsets(self, iterator): - """This function is used to iterate through utterances with different offsets for each file. - """ + """This function is used to iterate through utterances with different offsets for each file.""" class TarredAudioLoopOffsets: def __init__(self, collection): @@ -944,8 +946,7 @@ def _collate_fn(self, batch): return _speech_collate_fn(batch, self.pad_id) def _build_sample(self, tup): - """Builds the training sample by combining the data from the WebDataset with the manifest info. - """ + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" audio_bytes, audio_filename, offset_id = tup # Grab manifest entry from self.manifest_preprocessor.collection @@ -1316,7 +1317,9 @@ class BucketingDataset(IterableDataset): """ def __init__( - self, dataset: IterableDataset, bucketing_batch_size: int, + self, + dataset: IterableDataset, + bucketing_batch_size: int, ): self.wrapped_dataset = dataset self.bucketing_batch_size = bucketing_batch_size diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 000b1a8f0839..e9e97d3d32d7 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -13,7 +13,6 @@ # limitations under the License. 
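# Illustrative sketch (editor's aside, not part of the patch) of the relaxed collate behaviour
# introduced above: during inference the dataset now yields (audio, audio_len, None, None)
# (see AudioDataset.get_item in transcription.py further down), and _speech_collate_fn returns
# None for the token tensors instead of fabricating dummy text. Assumes _speech_collate_fn is
# importable as shown; the waveforms and pad_id below are placeholders.
import torch
from nemo.collections.asr.data.audio_to_text import _speech_collate_fn

inference_batch = [
    (torch.randn(16000), torch.tensor(16000), None, None),
    (torch.randn(8000), torch.tensor(8000), None, None),
]
audio, audio_lens, tokens, token_lens = _speech_collate_fn(inference_batch, pad_id=0)
assert audio.shape == (2, 16000)  # the shorter signal is zero-padded to the longest one
assert tokens is None and token_lens is None  # no text tokens are created for inference batches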
from typing import Callable, Sequence -import omegaconf import torch.utils.data from lhotse import CutSet from lhotse.cut import MixedCut, MonoCut @@ -21,7 +20,9 @@ from lhotse.dataset.collation import collate_vectors from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.prompts.canary import CanaryPromptFormatter from nemo.collections.common.tokenizers import CanaryTokenizer, TokenizerSpec +from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): @@ -57,21 +58,21 @@ def __init__( def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: audio, audio_lens, cuts = self.load_audio(cuts) - tokens, prompt_tokens = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference) + prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference) - tokens = [torch.as_tensor(t) for t in tokens] - token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) - tokens = collate_vectors(tokens, padding_value=self.padding_value) + prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers] + prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long) + prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value) if self.inference: - prompt_tokens = [torch.as_tensor(t) for t in prompt_tokens] - prompt_token_lens = torch.tensor([t.size(0) for t in prompt_tokens], dtype=torch.long) - prompt_tokens = collate_vectors(prompt_tokens, padding_value=self.padding_value) + prompts = [torch.as_tensor(t) for t in prompts] + prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long) + prompts = collate_vectors(prompts, padding_value=self.padding_value) else: - prompt_tokens = None - prompt_token_lens = None + prompts = None + prompts_lens = None - return audio, audio_lens, tokens, token_lens, prompt_tokens, prompt_token_lens + return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens # Mapping from a string name to a known prompt formatter function. @@ -105,7 +106,9 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool] @registered_prompt_format_fn -def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -> Sequence[Sequence[int]]: +def canary( + cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: """ Prepend and append control tokens to the token sequence as per Canary format. @@ -132,116 +135,53 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) - assert isinstance( tokenizer._tokenizer, CanaryTokenizer ), "To use 'canary' prompt format, you must use the CanaryTokenizer." - tokenizer = tokenizer._tokenizer + formatter = CanaryPromptFormatter(tokenizer._tokenizer) - tokens, prompts = [], [] + prompts_with_answers, prompts = [], [] for cut in cuts: if isinstance(cut, MixedCut): cut = cut._first_non_padding_cut - assert isinstance(cut, MonoCut), "Expected MonoCut." 
+ if not isinstance(cut, MonoCut): + raise TypeError( + f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" + ) # first, validate the utterance - missing_keys = [k for k in ("source_lang", "target_lang", "taskname", "pnc") if k not in cut.custom] + expected_slots = set(formatter.get_slots("user")) + missing_keys = expected_slots - set(cut.custom) + if "task" in missing_keys and "taskname" in cut.custom: + # Compatibility with "old" Canary manifest format. + # For compatbility with inference options, this slot is now called "task". + cut.custom["task"] = cut.custom["taskname"] + missing_keys.remove("task") if missing_keys: raise RuntimeError( f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" f"Please ensure that every utterance in the input manifests contains these keys." ) - # Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together. - texts = [sup.text for sup in cut.supervisions] - langs = [sup.language for sup in cut.supervisions] - taskname = cut.custom['taskname'] - pnc = cut.custom['pnc'] - source_lang = cut.custom['source_lang'] - target_lang = cut.custom['target_lang'] - - tokens.append(canary_prompt(tokenizer, texts, langs, source_lang, target_lang, taskname, pnc)) - if inference: - prompts.append(canary_prompt(tokenizer, None, None, source_lang, target_lang, taskname, pnc)) - return tokens, prompts - - -def canary_prompt( - tokenizer: CanaryTokenizer, - text: str | list[str] | None, - language: str | list[str] | None, - source_language: str, - target_language: str, - taskname: str, - pnc: str, -) -> list[int]: - if isinstance(text, str): - text = [text] - if isinstance(language, str): - language = [language] - - if text is not None: - try: - tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[]) - except omegaconf.errors.KeyValidationError as e: - raise ProbablyIncorrectLanguageKeyError( - "We couldn't select the right tokenizer, which could be due to issues with reading " - "the language from the manifest. " - "If you're training, try setting lang_field='' to a different value (probably 'target_lang' or 'lang'). " - "If you're using model.transcribe() directly, please use override_config kwarg to set this. " - "If you're using transcribe_speech.py, use option gt_lang_attr_name='...' " - ) from e - else: - tokens = None # create prompt for inference - - # bos - prompted_tokens = [tokenizer.bos_id] - - if tokens is not None and len(tokens) == 0: - # no speech token - prompted_tokens.append(tokenizer.nospeech_id) - else: - # first, validate the utterance - if source_language is None or target_language is None or taskname is None or pnc is None: - raise RuntimeError( - f"Missing keys provided to prompt: " - f"source_langauge={source_language},\n" - f"target_language={target_language},\n" - f"taskname={taskname},\n" - f"pnc={pnc}\n" - f"Please ensure that every utterance in the input manifests contains these keys." 
- ) - - # src_lang_id/no_speech - src_lang_id = tokenizer.spl_token_to_id(source_language) - prompted_tokens.append(src_lang_id) - - # task - task = taskname - if task == 'asr' or task == "transcribe": - prompted_tokens.append(tokenizer.spl_token_to_id("transcribe")) - elif task == 's2t_translation' or task == 'ast' or task == "translate": - prompted_tokens.append(tokenizer.spl_token_to_id("translate")) - else: - raise ValueError(f"Unknown task: {task}") - - # tgt_lang_id - tgt_lang_id = tokenizer.spl_token_to_id(target_language) - prompted_tokens.append(tgt_lang_id) - - # PnC - pnc = f"{pnc}".lower().strip() # to account for bool or str - if pnc in {'yes', 'true'}: - prompted_tokens.append(tokenizer.spl_token_to_id("pnc")) - elif pnc in {'no', 'false'}: - prompted_tokens.append(tokenizer.spl_token_to_id("nopnc")) - else: - raise ValueError(f"Unknown value for key 'pnc': {pnc}") - - # text (only in training) - if tokens is not None: - prompted_tokens.extend(tokens) + encoded = formatter.encode_dialog( + turns=[ + dict( + role="user", + slots={ + **{slot: cut.custom[slot] for slot in expected_slots}, + formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, + }, + ), + dict( + role="assistant", + slots={ + "text": ' '.join(s.text for s in cut.supervisions), + formatter.PROMPT_LANGUAGE_SLOT: cut.custom["target_lang"], + }, + ), + ] + ) + prompts_with_answers.append(encoded["input_ids"]) + prompts.append(encoded["context_ids"]) - # eos (only in training) - if tokens is not None: - prompted_tokens.append(tokenizer.eos_id) - return prompted_tokens + return prompts_with_answers, prompts class ProbablyIncorrectLanguageKeyError(RuntimeError): diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index b11d744a7e6a..880f8bb3a004 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -13,6 +13,7 @@ # limitations under the License. 
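# Illustrative sketch (editor's aside, not part of the patch) of a manifest entry accepted by the
# rewritten `canary` prompt formatting function above. The required keys are exactly the slots of
# CanaryPromptFormatter's "user" turn (source_lang, target_lang, task, pnc); the legacy key
# "taskname" is still recognized and remapped to "task" for older Canary manifests.
# The audio path, duration, and text below are placeholder values.
manifest_entry = {
    "audio_filepath": "utt0001.wav",
    "duration": 4.7,
    "text": "hello world",
    "source_lang": "en",  # wrapped into <|en|> by CanaryPromptFormatter
    "target_lang": "en",  # wrapped into <|en|>
    "task": "asr",        # mapped to <|transcribe|>; any other value maps to <|translate|>
    "pnc": "yes",         # mapped to <|pnc|> / <|nopnc|>
}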
import os +import warnings from dataclasses import dataclass, field from math import ceil from typing import Any, Dict, List, Optional, Union @@ -45,6 +46,7 @@ from nemo.collections.common.metrics import GlobalAverageLossMetric from nemo.collections.common.parts import transformer_weights_init from nemo.collections.common.parts.preprocessing.manifest import get_full_path +from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.core.classes.common import typecheck from nemo.core.neural_types import ( AudioSignal, @@ -100,10 +102,7 @@ class MultiTaskTranscriptionConfig(TranscribeConfig): Configuration for Multi Task Transcription """ - task: Optional[str] = None - pnc: Optional[bool] = None - source_lang: Optional[str] = None - target_lang: Optional[str] = None + prompt: list[dict[str, dict[str, str]]] | None = None text_field: str = "answer" lang_field: str = "target_lang" @@ -112,10 +111,7 @@ class MultiTaskTranscriptionConfig(TranscribeConfig): ) def __post_init__(self): - required_fields = ['task', 'pnc', 'source_lang', 'target_lang', 'text_field', 'lang_field'] - for field in required_fields: - if not hasattr(self, field): - raise ValueError(f"`{field}` must be present in the transcription config: {self}") + self.prompt = parse_multitask_prompt(self.prompt) class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRTranscriptionMixin): @@ -134,6 +130,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) + prompt_cls = PromptFormatter.resolve(self.prompt_format) + self.prompt = prompt_cls( + tokenizer=self.tokenizer, + defaults=OmegaConf.to_container(cfg.get("prompt_defaults")), + ) + # Setup audio preprocessor self.preprocessor = EncDecMultiTaskModel.from_config_dict(self.cfg.preprocessor) # Setup audio encoder @@ -391,15 +393,12 @@ def transcribe( audio: Union[str, List[str], np.ndarray, DataLoader], batch_size: int = 4, return_hypotheses: bool = False, - task: Optional[str] = None, - pnc: Optional[bool] = None, - source_lang: Optional[str] = None, - target_lang: Optional[str] = None, num_workers: int = 0, channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, override_config: Optional[MultiTaskTranscriptionConfig] = None, + **prompt, ) -> Union[List[str], List[Hypothesis]]: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. @@ -412,15 +411,12 @@ def transcribe( Bigger will result in better throughput performance but would use more memory. return_hypotheses: (bool) Either return hypotheses or text With hypotheses can do some postprocessing like getting timestamp or rescoring - task: (str) task name. Defaults to `asr`. - pnc: (bool) whether to apply punctuation and capitalization or not. Defaults to True. - source_lang: (str) source language. Defaults to `en`. - target_lang: (str) target language. Defaults to `en`. num_workers: (int) number of workers for DataLoader channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. verbose: (bool) whether to display tqdm progress bar override_config: (Optional[MultiTaskTranscriptionConfig]) A config to override the default config. 
+ **prompt: Optional input to construct the prompts for the model. Accepted formats are: 1) legacy Canary-1B API source_lang=, target_lang=, etc. 2) explicit single-turn role=, slots={: , ...} 3) explicit multi-turn: turns=[{"role": , "slots": {: , ...}}] Returns: A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files @@ -433,10 +429,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, - task=task, - pnc=pnc, - source_lang=source_lang, - target_lang=target_lang, + prompt=prompt, ) else: if not isinstance(override_config, MultiTaskTranscriptionConfig): @@ -738,9 +731,6 @@ def _transcribe_on_begin(self, audio, trcfg: MultiTaskTranscriptionConfig): if hasattr(trcfg, '_internal') and hasattr(trcfg._internal, 'manifest_path'): trcfg._internal.manifest_filepath = manifest_path - elif isinstance(audio, (np.ndarray, torch.Tensor)): - raise NotImplementedError("Transcribing from numpy or torch tensors is not supported yet.") - def _transcribe_input_manifest_processing( self, audio_files: List[str], temp_dir: str, trcfg: MultiTaskTranscriptionConfig ) -> Dict[str, Any]: @@ -792,7 +782,47 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig): log_probs, encoded_len, enc_states, enc_mask = self.forward( input_signal=batch[0], input_signal_length=batch[1] ) - decoder_input_ids = batch[-2].to(trcfg._internal.device) + if len(batch) == 6: + # Prompt provided by the dataloader. + decoder_input_ids = batch[4] + else: + # The dataloader provided only audio + audio_lens, so we + # are constructing the prompt dynamically using TranscribeConfig. + + # Now ask the prompt formatter about which slots are required. + # It will return a default prompt structure with default slot values (if available, None otherwise). + # We iterate over that structure and update slot values based on ``trcfg.prompt``. + default_turns = self.prompt.get_default_dialog_slots() + if not trcfg.prompt: + # No turns were provided, use defaults. + turns = default_turns + else: + # Turns were provided, iterate over them and fill missing slot values using defaults.. + turns = trcfg.prompt.copy() # shallow copy #1: don't override the config + for turn in turns: + role = turn["role"] + # Check if we have defaults for this role. + # There shouldn't be more than a single turn for a given role, but if there are, + # we'll emit a warning. + if default_turns_for_role := [t for t in default_turns if t["role"] == role]: + if len(default_turns_for_role) > 1: + warnings.warn( + f"More than one default turn detected for {role=}. " + f"We'll be using default slot values for the first turn of {role=} only." + ) + default_slots = default_turns_for_role[0]["slots"] + turn["slots"] = turn["slots"].copy() # shallow copy #1: don't override the config + # fill missing slots using defaults + for slot, val in default_slots.items(): + if turn["slots"].get(slot) is None: + turn["slots"][slot] = val + + decoder_input_ids = ( + self.prompt.encode_dialog(turns=turns)["context_ids"] + .unsqueeze(0) + .repeat(batch[0].shape[0], 1) + .to(trcfg._internal.device) + ) output = dict( log_probs=log_probs, encoded_lengths=encoded_len, @@ -906,6 +936,8 @@ def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: Mult Returns: A list of dictionaries with the audio file paths fixed. 
""" + # This method is a legacy helper for Canary that checks whether prompt slot values were provided + # in the input manifest and if not, it injects the defaults. out_json_items = [] for item in json_items: if isinstance(item, str): @@ -913,28 +945,21 @@ def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: Mult entry = { 'audio_filepath': item, 'duration': 100000, - 'source_lang': 'en' if trcfg.source_lang is None else trcfg.source_lang, - 'taskname': 'asr' if trcfg.task is None else trcfg.task, - 'target_lang': 'en' if trcfg.target_lang is None else trcfg.target_lang, - 'pnc': 'yes' if trcfg.pnc is None else 'yes' if trcfg.pnc else 'no', trcfg.text_field: 'nothing', } elif isinstance(item, dict): entry = item entry['audio_filepath'] = get_full_path(entry['audio_filepath'], manifest_file=manifest_path) - - if 'source_lang' not in entry: - entry['source_lang'] = 'en' if trcfg.source_lang is None else trcfg.source_lang - if 'taskname' not in entry: - entry['taskname'] = 'asr' if trcfg.task is None else trcfg.task - if 'target_lang' not in entry: - entry['target_lang'] = 'en' if trcfg.target_lang is None else trcfg.target_lang - if 'pnc' not in entry: - entry['pnc'] = 'yes' if trcfg.pnc is None else 'yes' if trcfg.pnc else 'no' if trcfg.text_field not in entry: entry[trcfg.text_field] = 'nothing' else: raise ValueError(f"Expected str or dict, got {type(item)}") + default_turn = [t for t in trcfg.prompt if t["role"] == "user"] + default_turn = default_turn[0]["slots"] if default_turn else {} + for k, dv in (("source_lang", "en"), ("target_lang", "en"), ("taskname", "asr"), ("pnc", "yes")): + if k not in entry: + # last-chance fallback injecting legacy Canary defaults if none were provided. + entry[k] = default_turn.get(k, dv) out_json_items.append(entry) return out_json_items @@ -977,3 +1002,76 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa text = [self.decoding.strip_special_tokens(t) for t in text] return text + + +def parse_multitask_prompt(prompt: dict | None) -> list[dict]: + if prompt is None or not prompt: + return [] + + # Case 1. + # Multi-turn prompting format. This format conforms to PromptFormatter API and needs no further modification. + # This format allows to condition the model on chat history, system+user prompts, etc. + # Example: + # model.transcribe( + # audio, + # turns=[ + # dict( + # role="user", + # slots=dict( + # source_lang='en', target_lang='de', task='asr', pnc=True, context='translate this text' + # ), + # ), + # dict( + # role="assistant", + # slots=dict(message="Calculating the translation of given text. Do you want to proceed?"), + # ), + # dict( + # role="user", + # slots=dict( + # source_lang='en', target_lang='de', task='asr', pnc=True, context='Yes, please proceed.' + # ), + # ), + # ], + # ) + if 'turns' in prompt: + assert ( + len(prompt) == 1 + and isinstance(prompt["turns"], list) + and all(isinstance(t, dict) and "role" in t and "slots" in t for t in prompt["turns"]) + ), ( + f"When providing a multi-turn prompt through 'turns', no other keys are allowed " + f"and the value under prompt['turns'] must be a list of dicts with roles and slot values " + f"(we received {prompt=})" + ) + return prompt["turns"] + + values_are_dicts = any(isinstance(v, dict) for k, v in prompt.items() if k != "slots") + assert not values_are_dicts, ( + f"We don't support dict values for prompt keys other than 'slots'. " f"We received {prompt=}" + ) + + # Case 2. 
+ # Single-turn prompting format with explicitly provided role and slot names and values. + # We create a 1-item multi-turn prompt from this input. + # Example: + # model.transcribe( + # audio, + # role="user", + # slots=dict(source_lang='en', target_lang='de', task='asr', pnc=True, context='translate this text'), + # ) + if "role" in prompt and "slots" in prompt: + assert isinstance(prompt["slots"], dict), ( + f"When providing a single-turn prompt through 'role', 'slots' must also be provided " + f"(we received {prompt=})." + ) + return [prompt] + + # Case 3. + # Legacy prompting format for Canary-1B preserved for backward compatibility. + # Extra fields are converted to a single-turn prompt with role "user" (unless overridden with 'role'). + # Example: + # model.transcribe( + # audio, pnc=True, source_lang='en', target_lang='de', task='asr', context='translate this text' + # ) + role = prompt.pop("role", "user") + return [dict(role=role, slots=prompt)] diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index df8d6bac50a9..261e97a225dd 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -148,11 +148,9 @@ def get_item(self, index): # Calculate seq length seq_len = torch.tensor(samples.shape[0], dtype=torch.long) - # Dummy text tokens - text_tokens = torch.tensor([0], dtype=torch.long) - text_tokens_len = torch.tensor(1, dtype=torch.long) - - return (samples, seq_len, text_tokens, text_tokens_len) + # Typically NeMo ASR models expect the mini-batch to be a 4-tuple of (audio, audio_len, text, text_len). + # For inference, we set text and text_len to None to not disrupt the shape of the tuple. + return samples, seq_len, None, None class TranscriptionMixin(ABC): diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index 71c945b66255..51a46184e66f 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -21,7 +21,6 @@ from omegaconf import OmegaConf from torch.utils.data import DataLoader -from nemo.collections.asr.data.audio_to_text_lhotse_prompted import canary_prompt from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.preprocessing.features import normalize_batch @@ -444,7 +443,10 @@ def _convert_buffer_to_features(self): device = self.asr_model.device audio_signal = samples.unsqueeze_(0).to(device) audio_signal_len = torch.Tensor([samples.shape[1]]).to(device) - features, features_len = self.raw_preprocessor(input_signal=audio_signal, length=audio_signal_len,) + features, features_len = self.raw_preprocessor( + input_signal=audio_signal, + length=audio_signal_len, + ) features = features.squeeze() self._update_feature_buffer(features[:, -self.feature_chunk_len :]) @@ -479,7 +481,10 @@ def __init__(self, samples, frame_len, preprocessor, device, pad_to_frame_len=Tr self._feature_frame_len = frame_len / timestep_duration audio_signal = torch.from_numpy(self._samples).unsqueeze_(0).to(device) audio_signal_len = torch.Tensor([self._samples.shape[0]]).to(device) - self._features, self._features_len = preprocessor(input_signal=audio_signal, length=audio_signal_len,) + self._features, self._features_len = preprocessor( + input_signal=audio_signal, + length=audio_signal_len, + ) self._features = 
self._features.squeeze() def __iter__(self): @@ -701,7 +706,12 @@ class for streaming frame-based ASR use reset() method to reset FrameASR's """ def __init__( - self, asr_model, frame_len=1.6, total_buffer=4.0, batch_size=4, pad_to_buffer_len=True, + self, + asr_model, + frame_len=1.6, + total_buffer=4.0, + batch_size=4, + pad_to_buffer_len=True, ): ''' Args: @@ -1183,7 +1193,9 @@ def _get_batch_preds(self): del best_hyp, pred def transcribe( - self, tokens_per_chunk: int, delay: int, + self, + tokens_per_chunk: int, + delay: int, ): """ Performs "middle token" alignment prediction using the buffered audio chunk. @@ -1210,7 +1222,12 @@ def transcribe( ids, toks = self._alignment_decoder(alignment, self.asr_model.tokenizer, self.blank_id) if len(ids) > 0 and a_idx < signal_end_idx: - self.unmerged[idx] = inplace_buffer_merge(self.unmerged[idx], ids, delay, model=self.asr_model,) + self.unmerged[idx] = inplace_buffer_merge( + self.unmerged[idx], + ids, + delay, + model=self.asr_model, + ) output = [] for idx in range(self.batch_size): @@ -1276,7 +1293,9 @@ def __init__( self.alignment_basepath = alignment_basepath def transcribe( - self, tokens_per_chunk: int, delay: int, + self, + tokens_per_chunk: int, + delay: int, ): if self.lcs_delay < 0: raise ValueError( @@ -1302,7 +1321,10 @@ def transcribe( if len(ids) > 0: self.unmerged[idx] = inplace_buffer_merge( - self.unmerged[idx], ids, delay, model=self.asr_model, + self.unmerged[idx], + ids, + delay, + model=self.asr_model, ) else: @@ -1588,15 +1610,17 @@ def get_input_tokens(self, sample: dict): f"We found sample that is missing the following keys: {missing_keys}" f"Please ensure that every utterance in the input manifests contains these keys. Sample: {sample}" ) - tokens = canary_prompt( - tokenizer=self.asr_model.tokenizer, - text=None, - language=None, - source_language=sample['source_lang'], - target_language=sample['target_lang'], - taskname=sample['taskname'], - pnc=sample['pnc'], - ) + tokens = self.asr_model.prompt.encode_dialog( + turns=[ + { + "role": "user", + "slots": { + **sample, + self.asr_model.prompt.PROMPT_LANGUAGE_SLOT: "spl_tokens", + }, + } + ] + )["context_ids"] else: raise ValueError(f"Unknown prompt format: {self.asr_model.prompt_format}") return torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0) # [1, T] @@ -1712,12 +1736,16 @@ def _get_batch_preds(self, keep_logits=False): encoded, encoded_len = results log_probs = self.asr_model.ctc_decoder(encoder_output=encoded) transcribed_texts, _ = self.asr_model.ctc_decoding.ctc_decoder_predictions_tensor( - decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + decoder_outputs=log_probs, + decoder_lengths=encoded_len, + return_hypotheses=False, ) else: log_probs, encoded_len, predictions = results transcribed_texts, _ = self.asr_model.decoding.ctc_decoder_predictions_tensor( - decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + decoder_outputs=log_probs, + decoder_lengths=encoded_len, + return_hypotheses=False, ) self.all_preds.extend(transcribed_texts) diff --git a/nemo/collections/common/prompts/__init__.py b/nemo/collections/common/prompts/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py new file mode 100644 index 000000000000..aadc976ba474 --- /dev/null +++ b/nemo/collections/common/prompts/canary.py @@ -0,0 +1,71 @@ +from nemo.collections.common.prompts.formatter 
import Modality, PromptFormatter +from nemo.collections.common.tokenizers.canary_tokenizer import ( + CANARY_BOS, + CANARY_EOS, + CANARY_NOPNC, + CANARY_PNC, + CANARY_SPECIAL_TOKENIZER, +) + + +class CanaryPromptFormatter(PromptFormatter): + NAME = "canary" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "user": { + "template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|", + "slots": { + "source_lang": Modality.Text, + "task": Modality.Text, + "target_lang": Modality.Text, + "pnc": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|text|{CANARY_EOS}", + "slots": { + "text": Modality.Text, + }, + }, + } + + def encode_turn(self, prompt_template: str, expected_slots: dict, slot_values: dict) -> list[int]: + # This method handles a level of indirection for Canary. + # It maps values provided in trcfg to the actual special tokens + # expected to be present in canary prompt. + # It used to be done in prompt_format_fn inside Dataset class corresponding to Canary, + # but we are not using it here anymore. + # This maps things such as '|task|: "asr"' to '|TASK|: "<|transcribe|>"'. + slot_values = map_manifest_values_to_special_tokens(slot_values) + return super().encode_turn( + prompt_template=prompt_template, expected_slots=expected_slots, slot_values=slot_values + ) + + +def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[str, str]: + slot_values = slot_values.copy() + + any_special_token_present = False + + for k in ("source_lang", "target_lang"): + if k in slot_values and not ((v := slot_values[k]).startswith("<|") and v.endswith("|>")): + slot_values[k] = "<|" + slot_values[k] + "|>" + any_special_token_present = True + + k = "pnc" + if k in slot_values and slot_values[k] not in (CANARY_PNC, CANARY_NOPNC): + slot_values[k] = CANARY_PNC if slot_values[k] in ("yes", "1", "True", "true") else CANARY_NOPNC + any_special_token_present = True + + # Note: we re-map 'taskname' to 'task' for compatibility with earlier versions of Canary training. + for k in ("task", "taskname"): + if k in slot_values and slot_values[k] not in ("<|transcribe|>", "<|translate|>"): + slot_values["task"] = "<|transcribe|>" if slot_values[k] == "asr" else "<|translate|>" + any_special_token_present = True + + # Auto-inject which tokenizer to look up in CanaryTokenizer if not provided, + # and slots for this turn correspond to user prompt. + if any_special_token_present and PromptFormatter.PROMPT_LANGUAGE_SLOT not in slot_values: + slot_values[PromptFormatter.PROMPT_LANGUAGE_SLOT] = CANARY_SPECIAL_TOKENIZER + + return slot_values diff --git a/nemo/collections/common/prompts/example.py b/nemo/collections/common/prompts/example.py new file mode 100644 index 000000000000..3589efb938f4 --- /dev/null +++ b/nemo/collections/common/prompts/example.py @@ -0,0 +1,36 @@ +""" +Implemented following the guide at https://www.promptingguide.ai/models/phi-2#phi-2-usage +""" + +from nemo.collections.common.prompts.formatter import Modality, PromptFormatter + + +class ExamplePromptFormatter(PromptFormatter): + """ + The simplest possible prompt formatter implementation. + + It defines a dialog of the form: + + User: Hi. + Assistant: Hi, how can I help you? + User: What's the time? + Assistant: It's 9 o'clock. 
+ + """ + + NAME = "example_prompt_format" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "user": { + "template": f"User: |message|\n", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"Assistant: |message|\n", + "slots": { + "message": Modality.Text, + }, + }, + } diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py new file mode 100644 index 000000000000..524b2e62c5a3 --- /dev/null +++ b/nemo/collections/common/prompts/formatter.py @@ -0,0 +1,347 @@ +from abc import ABC +from enum import Enum +from functools import lru_cache +from typing import Any, Type + +import torch + +from nemo.collections.common.tokenizers import AggregateTokenizer, TokenizerSpec + +PREAMBLE_ROLE = "preamble" + +# Slots used to define when special tokens bos/eos should be inserted. +# These are special in the sense of how sentencepiece defines special tokens: +# They have to be specially inserted into the token sequence, and if they appear in the tokenized string, +# SPE wouldn't use the special token ids but rather tokenize them as if they were normal strings. +# We mimic SPE's behavior if these special slots are present in the template definition. +# To achieve that, insert |bos| / |eos| at the beginning/end of template. +# E.g., inserting only bos in llama2 user role: "template": "|bos|[INST] |message| [\INST]" +BOS_SLOT = "|bos|" +EOS_SLOT = "|eos|" + + +class Modality(Enum): + """ + Modalities supported as PromptFormatter slot values. + """ + + Text = "text" + + def matches(self, value: Any) -> bool: + """ + Checks if the provided value is compatible with an instance of Modality. + """ + match self: + case Modality.Text: + return isinstance(value, str) + case _: + return False + + +class PromptFormatter(ABC): + """ + :class:`~nemo.collections.common.prompts.formatter.PromptFormatter` is intended to simplify + working with various prompt format templates and encoding them into token ID tensors. + + It assumes a dialog-like structure, which is a list of turns, with each turn assigned to a role. + Sub-classes of PromptFormatter define turn templates for each role under TEMPLATE class attribute. + Each template may define some constant parts (e.g. begin-of-turn or end-of-turn tokens, whitespaces, etc.) + and variable parts which we call "slots", that will be provided by the user during training or inference. + + A role is typically "user" and "assistant", and some popular models also use a "system" role. + Other roles may be defined as well. We expect the role corresponding to the model's responses + will be registered under class attribute called OUTPUT_ROLE. + We reserve a special "preamble" role with no slots that will be inserted at the beginning of + the formatted prompt, if "preamble" is present in TEMPLATE. + + A turn is a dict with keys "role" and "slots", where "slots" are a dict that maps slot names + to values that should be filled in the template. + For example, a user role template may be ``"Question: |message|"`` and corresponding ``slots`` would then be + ``{"message": "What time is it?"}``. + + There is a special slot called ``|prompt_language|`` that's used to select the sub-tokenizer in + :class:`~nemo.collections.common.tokenizers.aggregate_tokenizer.AggregateTokenizer`. + It's only used when the tokenizer is aggregate; otherwise it's discarded. + + PromptFormatter supports constructing prompts for training (complete context and answers) + and for inference (context-only). 
+    Training/inference is determined automatically; if the last role in a dialog is the OUTPUT_ROLE,
+    that's an 'asked-and-answered' scenario, so we assume it's intended for training.
+    We'll create a dict with tokenized results available under the following keys:
+
+    * ``context_ids`` (all turns minus last one),
+    * ``answer_ids`` (last turn)
+    * ``input_ids`` (previous two values concatenated)
+    * ``mask`` (boolean mask tensor of the same length as ``input_ids`` that's set to True on OUTPUT_ROLE turns)
+
+    Typically, the user will use the ``encode_dialog`` method, providing a list of turns to it.
+    Example showing how to construct model inputs/outputs for training::
+
+        >>> formatter = PromptFormatter(tokenizer)
+        ... encoded_for_training = formatter.encode_dialog(
+        ...     turns=[
+        ...         {"role": "user", "slots": {"message": "What time is it?"}},
+        ...         {"role": "assistant", "slots": {"message": "Ten o'clock."}},
+        ...         {"role": "user", "slots": {"message": "PM or AM?"}},
+        ...         {"role": "assistant", "slots": {"message": "AM, naturally! It's bright outside"}},
+        ...     ]
+        ... )
+
+    Another example that shows how to use the same method to generate prompts for inference::
+
+        >>> formatter = PromptFormatter(tokenizer)
+        ... encoded_for_inference = formatter.encode_dialog(
+        ...     turns=[
+        ...         {"role": "user", "slots": {"message": "What time is it?"}},
+        ...         {"role": "assistant", "slots": {"message": "Ten o'clock."}},
+        ...         {"role": "user", "slots": {"message": "PM or AM?"}},
+        ...     ]
+        ... )
+
+    """
+
+    # Used to support AggregateTokenizer; this key selects the right sub-tokenizer for each turn.
+    PROMPT_LANGUAGE_SLOT = "prompt_language"
+
+    # Subclasses will be registered under this name, to be used via PromptFormatter.resolve(name).
+    NAME = None
+
+    # Template is a dict that maps:
+    # * from a role name string (system/user/assistant/etc)
+    # * to a dict with keys
+    #   * "template" that has a string value (the prompt template)
+    #   * "slots" that has a value of dict[str, Modality]
+    #     * keys of slots are the names of formattable slots in the prompt template
+    #     * values of slots are :class:`Modality` objects that can be used to check
+    #       whether a specific value conforms to a given modality's requirements
+    #       (e.g., Modality.Text may expect string objects).
+    # Template is intended to be defined by the child classes.
+    TEMPLATE = None
+
+    # Turns under this role indicate responses by the model; if the last turn in
+    # PromptFormatter.encode_dialog() ends with this role, it indicates a training example.
+    OUTPUT_ROLE = None
+
+    # Internal reserved field.
+    _REGISTERED_FORMATTERS = {}
+
+    def __init__(self, tokenizer: TokenizerSpec, defaults: list[dict] | None = None) -> None:
+        self.tokenizer = tokenizer
+        self._defaults = defaults if defaults is not None else []
+        self._validate_defaults()
+
+    def __init_subclass__(cls, **kwargs) -> None:
+        ERR = "PromptFormatter subclass definition error:"
+        if cls.__name__ not in cls._REGISTERED_FORMATTERS:
+            for attr in ("NAME", "TEMPLATE", "OUTPUT_ROLE"):
+                assert (
+                    getattr(cls, attr, None) is not None
+                ), f"{ERR} PromptFormatter subclass {cls} did not define a class attribute {attr}"
+            assert cls.NAME not in cls._REGISTERED_FORMATTERS, (
+                f"Cannot register {cls.__name__} under {cls.NAME}: another prompt formatter of type "
+                f"{cls._REGISTERED_FORMATTERS[cls.NAME]} has already been registered under this name."
+ ) + cls._REGISTERED_FORMATTERS[cls.NAME] = cls + if "preamble" in cls.TEMPLATE: + assert ( + len(cls.TEMPLATE["preamble"].get("slots", [])) == 0 + ), f"{ERR} Slots are not allowed for preamble template, but we found: '{cls.TEMPLATE['preamble']}'" + for role in cls.get_roles(): + template = cls.get_template(role) + for slot in cls.get_slots(role): + assert ( + _mangled(slot) in template + ), f"{ERR} Slot '{slot}' not found in template '{template}' for role '{role}'" + super().__init_subclass__(**kwargs) + + @classmethod + def resolve(cls, name: str) -> Type["PromptFormatter"]: + if name not in cls._REGISTERED_FORMATTERS: + raise RuntimeError( + f"Unknown prompt formatter: '{name}' (known formats: {', '.join(cls._REGISTERED_FORMATTERS.keys())})" + ) + return cls._REGISTERED_FORMATTERS[name] + + @classmethod + @lru_cache(1) + def get_roles(cls) -> list[str]: + return list(cls.TEMPLATE.keys()) + + @classmethod + def get_slots(cls, role: str) -> dict[str, Modality]: + # returns a copy to avoid accidential mutation of a global object by the user + return cls.TEMPLATE[role].get("slots", {}).copy() + + @classmethod + def get_template(cls, role: str) -> str: + return cls.TEMPLATE[role]["template"] + + def get_default_dialog_slots(self) -> list[dict]: + """ + Returns a list of dialog turns that can be used as a skeleton to fill with actual slot values. + If ``PromptFormatter`` was initialized with ``defaults`` argument, this method will return the + defaults. Otherwise, every slot is pre-filled with ``None``. + """ + + def _get_default_for_role(role: str) -> dict: + for turn in self._defaults: + if turn["role"] == role: + return turn + return {} + + return [ + { + "role": role, + "slots": { + slot: _get_default_for_role(role).get("slots", {}).get(slot) for slot in self.get_slots(role) + }, + } + for role in self.get_roles() + if role != self.OUTPUT_ROLE + ] + + def encode_turn( + self, prompt_template: str, expected_slots: dict[str, Modality], slot_values: dict[str, Any] + ) -> list[int]: + prompt = prompt_template + for slot in expected_slots: + # For the final substitution of 'slot' in the template we have to mangle it to '|slot|' anyway, + # but 'slot' form enables to use valid python identifiers as **kwargs + # for passing slots around in user functions. + value = slot_values.get(slot) + assert value is not None, f"Missing required {slot=} in {slot_values=} for {prompt_template=}" + prompt = prompt.replace(_mangled(slot), value) + return self._apply_tokenizer(prompt, lang=slot_values.get(self.PROMPT_LANGUAGE_SLOT)) + + def encode_dialog(self, turns: list[dict]) -> dict[str, torch.Tensor]: + assert len(turns) > 0, "Empty dialog is not supported." + roles = self.get_roles() + + turn_tokens = [] + turn_token_counts = [] + turn_mask_values = [] + + if "preamble" in self.TEMPLATE: + preamble_turns = [idx for idx, t in enumerate(turns) if t["role"] == "preamble"] + if not preamble_turns: + turns = [{"role": "preamble", **self.TEMPLATE["preamble"]}] + turns + else: + assert ( + len(preamble_turns) == 1 and preamble_turns[0] == 0 + ), f"Preamble can only be presented at turn 0, but we found preamble turns at indexes {preamble_turns}." + + for turn in turns: + assert "role" in turn, f"A turn must have have a 'role' key. 
We received {turn=}" + role = turn["role"] + assert role in roles, f"Found turn with {role=}, but availables roles are {roles}" + expected_slots = self.get_slots(role) + slot_values = turn.get("slots", {}) + if expected_slots: + assert ( + slot_values + ), f"A turn for role {role} must have have a non-empty value under 'slots' key. We received {turn=}" + self._validate_slot_values(expected_slots, slot_values) + template = self.get_template(role) + tokens = self.encode_turn(template, expected_slots, slot_values) + turn_tokens.extend(tokens) + turn_token_counts.append(len(tokens)) + turn_mask_values.append(role == self.OUTPUT_ROLE) + + ans = {"input_ids": torch.tensor(turn_tokens, dtype=torch.long)} + if turn_mask_values[-1]: + # The last turn comes from OUTPUT_ROLE, i.e. it's a response from the system. + # This indicates it's a training example for which we provide context/answer/mask. + ans["context_ids"] = ans["input_ids"][: -turn_token_counts[-1]] + ans["answer_ids"] = ans["input_ids"][-turn_token_counts[-1] :] + ans["mask"] = torch.tensor( + [ + turn_mask_values[turn_idx] + for turn_idx, turn_len in enumerate(turn_token_counts) + for _ in range(turn_len) + ], + dtype=torch.bool, + ) + else: + ans["context_ids"] = ans["input_ids"] # context == input for inference + return ans + + def _apply_tokenizer(self, text: str, lang: str | None = None) -> list[int]: + # Check if the tokenizer is aggregate and perform extra checks. + is_agg = isinstance(self.tokenizer, AggregateTokenizer) + if is_agg: + assert lang is not None, ( + f"Missing key '{self.PROMPT_LANGUAGE_SLOT}' in slot_values -- cannot resolve " + f"the correct sub-tokenizer in the aggregate tokenizer." + ) + + # Strip bos/eos if present and remember to apply them later. + has_bos = text.startswith(BOS_SLOT) + has_eos = text.endswith(EOS_SLOT) + if has_bos: + text = text[len(BOS_SLOT) :] + if has_eos: + text = text[: -len(EOS_SLOT)] + + # Tokenize, selecting the right API depending on aggregate/normal tokenizer. + if is_agg: + tokens = self.tokenizer.text_to_ids(text, lang) + else: + tokens = self.tokenizer.text_to_ids(text) + + # Lazily look up bos/eos and apply them. Lazy has the advantage that if a tokenizer + # doesn't define bos/eos and the prompt format does not request them, everything just works. + if has_eos: + eos_id = self.tokenizer.get_eos(lang) if is_agg else self.tokenizer.eos + tokens.append(eos_id) + if has_bos: + bos_id = self.tokenizer.get_bos(lang) if is_agg else self.tokenizer.bos + tokens = [bos_id] + tokens + + return tokens + + def _validate_slot_values(self, expected: dict[str, Modality], received: dict[str, Any]) -> None: + missing = set(expected) - set(received) + assert not missing, f"The following slot values were not provided: {missing}" + for slot in expected: + expected_modality = expected[slot] + value = received[slot] + assert expected_modality.matches( + value + ), f"{slot=} received {value=} which does not match modality {expected_modality}" + + def _validate_defaults(self): + if not self._defaults: + return + + err = "Error in default prompt definition:" + assert isinstance(self._defaults, list) + for turn in self._defaults: + assert isinstance(turn, dict) + assert "role" in turn, f"{err} Missing required 'role' key. We received {turn=}" + role = turn["role"] + assert role in self.get_roles(), ( + f"{err} Invalid {role=} in {turn=} - " f"supported roles are: {self.get_roles()}." 
+ ) + if expected_slots := self.get_slots(role): + assert "slots" in turn, ( + f"{err} Missing required 'slots' key in {turn=} - " + f"we expected the following slots to be provided: {expected_slots}." + ) + for slot in turn["slots"]: + assert slot in expected_slots, ( + f"{err} Invalid {slot=} in {turn=}. " + f"The following slots are supported for {role=}: {expected_slots}" + ) + + +def _mangled(slot: str) -> str: + if not (slot[0] == "|" and slot[-1] == "|"): + return f"|{slot}|" + return slot + + +def _unmangled(slot: str) -> str: + if slot[0] == "|" and slot[-1] == "|": + return slot[1:-1] + return slot diff --git a/nemo/collections/common/prompts/gemma.py b/nemo/collections/common/prompts/gemma.py new file mode 100644 index 000000000000..e3b81c848a3e --- /dev/null +++ b/nemo/collections/common/prompts/gemma.py @@ -0,0 +1,29 @@ +""" +Implemented following the guide at https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format +""" + +from nemo.collections.common.prompts.formatter import Modality, PromptFormatter + +GEMMA_BOS = "" +GEMMA_END_OF_TURN = "" +GEMMA_NL = "\n\n" + + +class GemmaPromptFormatter(PromptFormatter): + NAME = "gemma" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "user": { + "template": f"{GEMMA_BOS}user\n|message|{GEMMA_END_OF_TURN}\n{GEMMA_BOS}model\n", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + # Note: that trailing NL is bothering me. + "template": f"|message|{GEMMA_END_OF_TURN}\n", + "slots": { + "message": Modality.Text, + }, + }, + } diff --git a/nemo/collections/common/prompts/llama.py b/nemo/collections/common/prompts/llama.py new file mode 100644 index 000000000000..fdaccfaa846e --- /dev/null +++ b/nemo/collections/common/prompts/llama.py @@ -0,0 +1,72 @@ +from nemo.collections.common.prompts.formatter import BOS_SLOT, EOS_SLOT, Modality, PromptFormatter + + +class Llama2PromptFormatter(PromptFormatter): + """ + This template has been validated to provide identical tokenized results to the official code + in https://github.com/meta-llama/llama/blob/main/llama/generation.py + """ + + NAME = "llama2" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "system_and_user": { + "template": f"{BOS_SLOT}[INST] <>\n|system|\n<>\n\n|message| [/INST]", + "slots": { + "system": Modality.Text, + "message": Modality.Text, + }, + }, + "user": { + "template": "|bos|[INST] |message| [/INST]", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|message| {EOS_SLOT}", + "slots": { + "message": Modality.Text, + }, + }, + } + + +LLAMA3_BOS = "<|begin_of_text|>" +LLAMA3_HEADER_BEGIN = "<|start_header_id|>" +LLAMA3_HEADER_END = "<|end_header_id|>" +LLAMA3_END_OF_TURN = "<|eot_id|>" +LLAMA3_NL = "\n\n" + + +class Llama3PromptFormatter(PromptFormatter): + """ + Implemented following the code at: + https://github.com/meta-llama/llama3/blob/main/llama/test_tokenizer.py#L56 + """ + + NAME = "llama3" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "preamble": { + "template": LLAMA3_BOS, + }, + "system": { + "template": f"{LLAMA3_HEADER_BEGIN}system{LLAMA3_HEADER_END}{LLAMA3_NL}|message|{LLAMA3_END_OF_TURN}", + "slots": { + "message": Modality.Text, + }, + }, + "user": { + "template": f"{LLAMA3_HEADER_BEGIN}user{LLAMA3_HEADER_END}{LLAMA3_NL}|message|{LLAMA3_END_OF_TURN}", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"{LLAMA3_HEADER_BEGIN}assistant{LLAMA3_HEADER_END}{LLAMA3_NL}|message|{LLAMA3_END_OF_TURN}", + "slots": { + "message": Modality.Text, + }, + }, + } diff --git 
a/nemo/collections/common/prompts/mistral.py b/nemo/collections/common/prompts/mistral.py new file mode 100644 index 000000000000..e882ac5973b1 --- /dev/null +++ b/nemo/collections/common/prompts/mistral.py @@ -0,0 +1,33 @@ +""" +Implemented following the guide at https://www.promptingguide.ai/models/mistral-7b#chat-template-for-mistral-7b-instruct +""" + +from nemo.collections.common.prompts.formatter import Modality, PromptFormatter + +MISTRAL_BOS = "" +MISTRAL_PROMPT_BEGIN = "[INST]" +MISTRAL_PROMPT_END = "[/INST]" +MISTRAL_END_OF_TURN = "" +MISTRAL_NL = "\n\n" + + +class MistralPromptFormatter(PromptFormatter): + NAME = "mistral" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "preamble": { + "template": MISTRAL_BOS, + }, + "user": { + "template": f"{MISTRAL_PROMPT_BEGIN} |message| {MISTRAL_PROMPT_END} ", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|message|{MISTRAL_END_OF_TURN}", + "slots": { + "message": Modality.Text, + }, + }, + } diff --git a/nemo/collections/common/prompts/phi2.py b/nemo/collections/common/prompts/phi2.py new file mode 100644 index 000000000000..67dad8d5dd82 --- /dev/null +++ b/nemo/collections/common/prompts/phi2.py @@ -0,0 +1,62 @@ +""" +Implemented following the guide at https://www.promptingguide.ai/models/phi-2#phi-2-usage +""" + +from nemo.collections.common.prompts.formatter import Modality, PromptFormatter + + +class Phi2QAPromptFormatter(PromptFormatter): + NAME = "phi2_qa" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "user": { + "template": f"Instruct: |message|\nOutput: ", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|message|", + "slots": { + "message": Modality.Text, + }, + }, + } + + +class Phi2ChatPromptFormatter(PromptFormatter): + NAME = "phi2_chat" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "user": { + "template": f"Human: |message|\nAI: ", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|message|", + "slots": { + "message": Modality.Text, + }, + }, + } + + +class Phi2CodePromptFormatter(PromptFormatter): + NAME = "phi2_code" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "user": { + "template": f"|message|\n", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|message|", + "slots": { + "message": Modality.Text, + }, + }, + } diff --git a/nemo/collections/common/tokenizers/aggregate_tokenizer.py b/nemo/collections/common/tokenizers/aggregate_tokenizer.py index 9c003c37525a..66ec28ebda4d 100644 --- a/nemo/collections/common/tokenizers/aggregate_tokenizer.py +++ b/nemo/collections/common/tokenizers/aggregate_tokenizer.py @@ -15,6 +15,7 @@ from typing import Dict, List, Union import numpy as np +import torch from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging @@ -124,7 +125,7 @@ def tokens_to_text(self, tokens, lang_id): return tokenizer.decode_pieces(tokens) def ids_to_text(self, ids): - if isinstance(ids, np.ndarray): + if isinstance(ids, (np.ndarray, torch.Tensor)): ids = ids.tolist() tokens = [] @@ -224,6 +225,12 @@ def tokens_to_ids(self, tokens: Union[str, List[str]], langs: Union[str, List[st ids.append(self.token_to_id(token, lang_id)) return ids + def get_bos(self, lang_id: str) -> int: + return self.tokenizers_dict[lang_id].bos + self.token_id_offset[lang_id] + + def get_eos(self, lang_id: str) -> int: + return self.tokenizers_dict[lang_id].eos + self.token_id_offset[lang_id] + @property def vocab(self): return self.vocabulary diff --git 
a/nemo/collections/common/tokenizers/canary_tokenizer.py b/nemo/collections/common/tokenizers/canary_tokenizer.py
index aed95c1f9312..6adcdd8cf734 100644
--- a/nemo/collections/common/tokenizers/canary_tokenizer.py
+++ b/nemo/collections/common/tokenizers/canary_tokenizer.py
@@ -24,7 +24,15 @@ __all__ = ['CanaryTokenizer']

 # Default tokens for compatibility with Canary.
-DEFAULT_TOKENS = ["<|nospeech|>", "", "<|endoftext|>", "<|startoftranscript|>", "<|pnc|>", "<|nopnc|>"]
+CANARY_BOS = "<|startoftranscript|>"
+CANARY_EOS = "<|endoftext|>"
+CANARY_PAD = ""
+CANARY_NOSPEECH = "<|nospeech|>"
+CANARY_PNC = "<|pnc|>"
+CANARY_NOPNC = "<|nopnc|>"
+DEFAULT_TOKENS = [CANARY_NOSPEECH, CANARY_PAD, CANARY_EOS, CANARY_BOS, CANARY_PNC, CANARY_NOPNC]
+
+CANARY_SPECIAL_TOKENIZER = "spl_tokens"

 class CanaryTokenizer(AggregateTokenizer):
@@ -37,26 +45,51 @@ def __init__(self, tokenizers: Dict):
         # for easy access of special tokens
         self.special_tokens = {}
-        for special in tokenizers['spl_tokens'].vocab:
+        for special in tokenizers[CANARY_SPECIAL_TOKENIZER].vocab:
             # Search for special prompting tokens
-            if (special.startswith("<|") and special.endswith("|>")) or special == "":
-                self.special_tokens[special] = self.token_to_id(special, lang_id='spl_tokens')
+            if (special.startswith("<|") and special.endswith("|>")) or special == CANARY_PAD:
+                self.special_tokens[special] = self.token_to_id(special, lang_id=CANARY_SPECIAL_TOKENIZER)

     @cached_property
     def eos_id(self) -> int:
-        return self.special_tokens["<|endoftext|>"]
+        return self.special_tokens[CANARY_EOS]

     @cached_property
     def bos_id(self) -> int:
-        return self.special_tokens["<|startoftranscript|>"]
+        return self.special_tokens[CANARY_BOS]

     @cached_property
     def nospeech_id(self) -> int:
-        return self.special_tokens["<|nospeech|>"]
+        return self.special_tokens[CANARY_NOSPEECH]

     @cached_property
     def pad_id(self) -> int:
-        return self.special_tokens[""]
+        return self.special_tokens[CANARY_PAD]
+
+    def text_to_ids(self, text, lang_id) -> list[int]:
+        if lang_id == CANARY_SPECIAL_TOKENIZER:
+            return self._tokenize_special_prompt(text)
+        if text.endswith(CANARY_EOS):
+            return super().text_to_ids(text[: -len(CANARY_EOS)], lang_id) + [self.eos_id]
+        return super().text_to_ids(text, lang_id)
+
+    def _tokenize_special_prompt(self, text: str) -> list[int]:
+        """
+        Tokenize the input special prompt of the following schema:
+
+        <|startoftranscript|><|source_lang|><|taskname|><|target_lang|><|pnc|>
+
+        Required because otherwise self.text_to_ids() returns a different result than what Canary had been trained with.
+        """
+        ans = []
+        assert text.count('>') == 5, f"Expected exactly 5 special tokens in Canary's prompt, got: {text}."
+        assert text.startswith(CANARY_BOS), text
+        for _ in range(5):
+            token = text[: text.find(">") + 1]
+            ans.append(self.special_tokens[token])
+            text = text[len(token) :]
+        assert len(text) == 0, text
+        return ans

     def spl_token_to_id(self, token):
         if token_id := self.special_tokens.get(f"<|{token}|>", None):
diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
index aed05673f6fa..4a47f0e49b1e
--- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
+++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -17,6 +17,7 @@

 import numpy as np
 import sentencepiece
+import torch

 from nemo.collections.common.parts.utils import if_exist
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -127,7 +128,7 @@ def tokens_to_text(self, tokens):
         return self.tokenizer.decode_pieces(tokens)

     def ids_to_text(self, ids):
-        if isinstance(ids, np.ndarray):
+        if isinstance(ids, (np.ndarray, torch.Tensor)):
             ids = ids.tolist()

         if self.legacy:
diff --git a/tests/collections/__init__.py b/tests/collections/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py
index d250fbcf74a1..986df09deacb
--- a/tests/collections/asr/test_asr_multitask_model_bpe.py
+++ b/tests/collections/asr/test_asr_multitask_model_bpe.py
@@ -80,9 +80,18 @@ def asr_model(test_data_dir):
             'dir': None,
             'type': 'agg',
             'langs': {
-                'spl_tokens': {'dir': os.path.join(test_data_dir, "asr", "tokenizers", "canary"), 'type': 'bpe',},
-                'en': {'dir': os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128"), 'type': 'wpe',},
-                'de': {'dir': os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128"), 'type': 'wpe',},
+                'spl_tokens': {
+                    'dir': os.path.join(test_data_dir, "asr", "tokenizers", "canary"),
+                    'type': 'bpe',
+                },
+                'en': {
+                    'dir': os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128"),
+                    'type': 'wpe',
+                },
+                'de': {
+                    'dir': os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128"),
+                    'type': 'wpe',
+                },
             },
             'custom_tokenizer': {
                 '_target_': 'nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer',
@@ -98,6 +107,9 @@
     modelConfig = DictConfig(
         {
             'prompt_format': 'canary',
+            'prompt_defaults': [
+                {"role": "user", "slots": {"source_lang": "en", "target_lang": "en", "task": "asr", "pnc": "yes"}}
+            ],
             'sample_rate': 16000,
             'preprocessor': DictConfig(preprocessor),
             'model_defaults': DictConfig(model_defaults),
@@ -304,10 +316,9 @@ def test_transcribe_tensor(self, asr_model, test_data_dir):
         audio, sr = sf.read(audio_file, dtype='float32')

         # Numpy array test
-        with pytest.raises(NotImplementedError):
-            outputs = asr_model.transcribe(audio, batch_size=1)
-            # assert len(outputs) == 1
-            # assert isinstance(outputs[0], str)
+        outputs = asr_model.transcribe(audio, batch_size=1)
+        assert len(outputs) == 1
+        assert isinstance(outputs[0], str)

     @pytest.mark.unit
     def test_build_tokenizer(self, asr_model, test_data_dir):
diff --git a/tests/collections/asr/test_custom_tokenizer.py b/tests/collections/asr/test_custom_tokenizer.py
index 5a033045b709..61692061661f
--- a/tests/collections/asr/test_custom_tokenizer.py
+++ b/tests/collections/asr/test_custom_tokenizer.py
@@ -67,7 +67,9 @@ class DummyModel(ASRBPEMixin, Serialization):
                 "spl_tokens": {"dir": special_tokenizer_path, "type": "bpe"},
                 "en": {"dir": lang_tokenizer_path, "type": "bpe"},
             },
-        "custom_tokenizer": {"_target_": "nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer",},
+        "custom_tokenizer": {
+            "_target_": "nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer",
+        },
         }
     )
     model._setup_aggregate_tokenizer(config)
@@ -83,5 +85,11 @@
     assert isinstance(tokenizer.tokenizers_dict["en"], SentencePieceTokenizer)
     assert tokenizer.tokenizers_dict["en"].vocab_size == 6

-    assert tokenizer.text_to_ids("<|startoftranscript|>", lang_id="spl_tokens") == [13, 4]  # "_" comes first
+    assert tokenizer.text_to_ids("<|startoftranscript|><|en|><|asr|><|en|><|pnc|>", lang_id="spl_tokens") == [
+        4,
+        9,
+        7,
+        9,
+        5,
+    ]
     assert tokenizer.text_to_ids("a", lang_id="en") == [14 + 1, 14 + 2]
diff --git a/tests/collections/common/prompt_formatters/conftest.py b/tests/collections/common/prompt_formatters/conftest.py
new file mode 100644
index 000000000000..e18f1072af24
--- /dev/null
+++ b/tests/collections/common/prompt_formatters/conftest.py
@@ -0,0 +1,51 @@
+import pytest
+
+from nemo.collections.common.tokenizers import CanaryTokenizer, SentencePieceTokenizer
+from nemo.collections.common.tokenizers.sentencepiece_tokenizer import create_spt_model
+
+# Note: We don't really define special tokens for this test so every 'special token'
+# will be represented as a number of regular tokens.
+TOKENIZER_TRAIN_TEXT = """
+Example system message.
+Example user message.
+Example assistant message.
+TEST
+[INST]
+[/INST]
+<s>
+</s>
+<<SYS>>
+<</SYS>>
+User: Assistant:
+user model
+Instruct Output
+\n\n
+<start_of_turn> <end_of_turn>
+<|
+|>
+<|en|> <|de|> <|fr|> <|es|> <|transcribe|> <|translate|> <|pnc|> <|nopnc|> <|startoftranscript|> <|endoftext|>
+Feel free to add new tokens for your own tests!?
+But know that if you do so, you may need to update the token IDs in the existing tests!
+So, it might be a good idea to create a new tokenizer instead when adding new prompt formats.
+""" + + +@pytest.fixture(scope="session") +def bpe_tokenizer(tmp_path_factory): + tmpdir = tmp_path_factory.mktemp("bpe_tokenizer") + text_path = tmpdir / "text.txt" + text_path.write_text(TOKENIZER_TRAIN_TEXT) + create_spt_model(str(text_path), vocab_size=512, sample_size=-1, do_lower_case=False, output_dir=str(tmpdir)) + return SentencePieceTokenizer(str(tmpdir / "tokenizer.model")) + + +@pytest.fixture(scope="session") +def canary_tokenizer(bpe_tokenizer, tmp_path_factory): + tmpdir = tmp_path_factory.mktemp("spl_tokens") + spl_tokens = CanaryTokenizer.build_special_tokenizer(["transcribe", "en"], tmpdir) + return CanaryTokenizer( + tokenizers={ + "spl_tokens": spl_tokens, + "en": bpe_tokenizer, + } + ) diff --git a/tests/collections/common/prompt_formatters/test_canary_prompt_formatter.py b/tests/collections/common/prompt_formatters/test_canary_prompt_formatter.py new file mode 100644 index 000000000000..ff786766b246 --- /dev/null +++ b/tests/collections/common/prompt_formatters/test_canary_prompt_formatter.py @@ -0,0 +1,50 @@ +from nemo.collections.common.prompts.canary import CanaryPromptFormatter + + +def test_canary_prompt_formatter_training(canary_tokenizer): + formatter = CanaryPromptFormatter(canary_tokenizer) + ans = formatter.encode_dialog( + [ + { + "role": "user", + "slots": { + "source_lang": "<|en|>", + "target_lang": "<|en|>", + "task": "<|transcribe|>", + "pnc": "<|pnc|>", + "prompt_language": "spl_tokens", + }, + }, + {"role": "assistant", "slots": {"text": "TEST", "prompt_language": "en"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids", "answer_ids", "mask"} + # fmt: off + assert ans["input_ids"].tolist() == [4, 8, 7, 8, 5, 11, 91, 30, 40, 3] + assert ans["context_ids"].tolist() == [4, 8, 7, 8, 5] + assert ans["answer_ids"].tolist() == [11, 91, 30, 40, 3] + assert ans["mask"].tolist() == [False] * 5 + [True] * 5 + # fmt: on + + +def test_canary_prompt_formatter_inference(canary_tokenizer): + formatter = CanaryPromptFormatter(canary_tokenizer) + ans = formatter.encode_dialog( + [ + { + "role": "user", + "slots": { + "source_lang": "<|en|>", + "target_lang": "<|en|>", + "task": "<|transcribe|>", + "pnc": "<|pnc|>", + "prompt_language": "spl_tokens", + }, + }, + ] + ) + assert set(ans) == {"input_ids", "context_ids"} + # fmt: off + assert ans["input_ids"].tolist() == ans["context_ids"].tolist() + assert ans["input_ids"].tolist() == [4, 8, 7, 8, 5] + # fmt: on diff --git a/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py b/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py new file mode 100644 index 000000000000..be1f6de1a873 --- /dev/null +++ b/tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py @@ -0,0 +1,40 @@ +from nemo.collections.common.prompts.gemma import GemmaPromptFormatter + + +def test_gemma_prompt_formatter_training(bpe_tokenizer): + formatter = GemmaPromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"message": "TEST"}}, + {"role": "assistant", "slots": {"message": "TEST"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids", "answer_ids", "mask"} + # fmt: off + assert ans["input_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20, + 30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26, + 18, 6, 60, 9, 7, 73, 61, 69, 1, 81, 20, 30, 104, 59, + 18, 26, 18, 6, 60, 9, 7] + assert ans["context_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20, + 30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 
26, + 18, 6, 60, 9, 7, 73, 61, 69] + assert ans["answer_ids"].tolist() == [1, 81, 20, 30, 104, 59, + 18, 26, 18, 6, 60, 9, 7] + assert ans["mask"].tolist() == [False] * 36 + [True] * 13 + # fmt: on + + +def test_gemma_prompt_formatter_inference(bpe_tokenizer): + formatter = GemmaPromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"message": "TEST"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids"} + # fmt: off + assert ans["input_ids"].tolist() == ans["context_ids"].tolist() + assert ans["input_ids"].tolist() == [ 21, 53, 18, 26, 18, 6, 60, 9, 7, 75, 31, 1, 81, 20, + 30, 104, 59, 18, 26, 18, 6, 60, 9, 7, 21, 53, 18, 26, + 18, 6, 60, 9, 7, 73, 61, 69] + # fmt: on diff --git a/tests/collections/common/prompt_formatters/test_llama2_prompt_formatter.py b/tests/collections/common/prompt_formatters/test_llama2_prompt_formatter.py new file mode 100644 index 000000000000..9636dd31c768 --- /dev/null +++ b/tests/collections/common/prompt_formatters/test_llama2_prompt_formatter.py @@ -0,0 +1,63 @@ +from nemo.collections.common.prompts.llama import Llama2PromptFormatter + + +def test_llama2_prompt_formatter_training(bpe_tokenizer): + formatter = Llama2PromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"message": "TEST"}}, + {"role": "assistant", "slots": {"message": "TEST"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids", "answer_ids", "mask"} + # fmt: off + assert ans["input_ids"].tolist() == [-1, 54, 42, 49, 30, 50, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50, 1, 81, 20, 30, -1] + assert ans["context_ids"].tolist() == [-1, 54, 42, 49, 30, 50, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50] + assert ans["answer_ids"].tolist() == [1, 81, 20, 30, -1] + assert ans["mask"].tolist() == [False] * 16 + [True] * 5 + # fmt: on + + +def test_llama2_prompt_formatter_inference(bpe_tokenizer): + formatter = Llama2PromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"message": "TEST"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids"} + # fmt: off + assert ans["input_ids"].tolist() == ans["context_ids"].tolist() + assert ans["input_ids"].tolist() == [-1, 54, 42, 49, 30, 50, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50] + # fmt: on + + +def test_llama2_prompt_formatter_training_with_system(bpe_tokenizer): + formatter = Llama2PromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "system_and_user", "slots": {"system": "TEST", "message": "TEST"}}, + {"role": "assistant", "slots": {"message": "TEST"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids", "answer_ids", "mask"} + # fmt: off + assert ans["input_ids"].tolist() == [-1, 54, 42, 49, 30, 50, 77, 13, 45, 13, 7, 7, 1, 81, 20, 30, 21, 66, 13, 45, 13, 7, 7, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50, 1, 81, 20, 30, -1] + assert ans["context_ids"].tolist() == [-1, 54, 42, 49, 30, 50, 77, 13, 45, 13, 7, 7, 1, 81, 20, 30, 21, 66, 13, 45, 13, 7, 7, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50] + assert ans["answer_ids"].tolist() == [1, 81, 20, 30, -1] + assert ans["mask"].tolist() == [False] * 33 + [True] * 5 + # fmt: on + + +def test_llama2_prompt_formatter_inference_with_system(bpe_tokenizer): + formatter = Llama2PromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "system_and_user", "slots": {"system": "TEST", "message": "TEST"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids"} + # fmt: off + assert ans["input_ids"].tolist() == ans["context_ids"].tolist() + assert 
ans["input_ids"].tolist() == [-1, 54, 42, 49, 30, 50, 77, 13, 45, 13, 7, 7, 1, 81, 20, 30, 21, 66, 13, 45, 13, 7, 7, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50] + # fmt: on diff --git a/tests/collections/common/prompt_formatters/test_mistral_prompt_formatter.py b/tests/collections/common/prompt_formatters/test_mistral_prompt_formatter.py new file mode 100644 index 000000000000..edc00d426952 --- /dev/null +++ b/tests/collections/common/prompt_formatters/test_mistral_prompt_formatter.py @@ -0,0 +1,32 @@ +from nemo.collections.common.prompts.mistral import MistralPromptFormatter + + +def test_mistral_prompt_formatter_training(bpe_tokenizer): + formatter = MistralPromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"message": "TEST"}}, + {"role": "assistant", "slots": {"message": "TEST"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids", "answer_ids", "mask"} + # fmt: off + assert ans["input_ids"].tolist() == [21, 8, 7, 54, 42, 49, 30, 50, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50, 1, 81, 20, 30, 66, 8, 7] + assert ans["context_ids"].tolist() == [21, 8, 7, 54, 42, 49, 30, 50, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50] + assert ans["answer_ids"].tolist() == [1, 81, 20, 30, 66, 8, 7] + assert ans["mask"].tolist() == [False] * 18 + [True] * 7 + # fmt: on + + +def test_mistral_prompt_formatter_inference(bpe_tokenizer): + formatter = MistralPromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"message": "TEST"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids"} + # fmt: off + assert ans["input_ids"].tolist() == ans["context_ids"].tolist() + assert ans["input_ids"].tolist() == [21, 8, 7, 54, 42, 49, 30, 50, 1, 81, 20, 30, 54, 72, 42, 49, 30, 50] + # fmt: on diff --git a/tests/collections/common/prompt_formatters/test_prompt_formatter_api.py b/tests/collections/common/prompt_formatters/test_prompt_formatter_api.py new file mode 100644 index 000000000000..26ade7da1415 --- /dev/null +++ b/tests/collections/common/prompt_formatters/test_prompt_formatter_api.py @@ -0,0 +1,147 @@ +import pytest + +from nemo.collections.common.prompts.canary import PromptFormatter +from nemo.collections.common.prompts.formatter import Modality + + +class _DummyPromptFormatter(PromptFormatter): + NAME = "_dummy_test_formatter" + TEMPLATE = { + "user": {"template": "|text|", "slots": {"text": Modality.Text}}, + "assistant": {"template": "|text|", "slots": {"text": Modality.Text}}, + } + OUTPUT_ROLE = "assistant" + + +def test_prompt_formatter_empty_dialog_exception(bpe_tokenizer): + formatter = _DummyPromptFormatter(bpe_tokenizer) + with pytest.raises(AssertionError): + formatter.encode_dialog([]) + + +def test_prompt_formatter_inference(bpe_tokenizer): + formatter = _DummyPromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog([{"role": "user", "slots": {"text": "hi"}}]) + recovered = bpe_tokenizer.ids_to_text(ans["input_ids"]) + assert recovered == "hi" + + +def test_prompt_formatter_training(bpe_tokenizer): + formatter = _DummyPromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"text": "hi"}}, + {"role": "assistant", "slots": {"text": "hello"}}, + ] + ) + recovered = bpe_tokenizer.ids_to_text(ans["input_ids"]) + assert recovered == "hi hello", recovered + + +def test_prompt_formatter_missing_role(bpe_tokenizer): + formatter = _DummyPromptFormatter(bpe_tokenizer) + with pytest.raises(AssertionError, match="A turn must have have a 'role' key"): + formatter.encode_dialog([{"slots": 
{"text": "hi"}}]) + + +def test_prompt_formatter_missing_slots(bpe_tokenizer): + formatter = _DummyPromptFormatter(bpe_tokenizer) + with pytest.raises( + AssertionError, match="A turn for role user must have have a non-empty value under 'slots' key" + ): + formatter.encode_dialog([{"role": "user"}]) + with pytest.raises( + AssertionError, match="A turn for role user must have have a non-empty value under 'slots' key" + ): + formatter.encode_dialog([{"role": "user", "slots": {}}]) + + +def test_prompt_formatter_aggregate_tokenizer(canary_tokenizer): + # Note the 'canary_tokenizer' arg which is an aggregate tokenizer fixture. + formatter = _DummyPromptFormatter(canary_tokenizer) + ans = formatter.encode_dialog( + [ + { + "role": "user", + "slots": { + "text": "hi", + "prompt_language": "en", + }, + } + ] + ) + recovered = canary_tokenizer.ids_to_text(ans["input_ids"]) + assert recovered == " hi" + + +def test_prompt_formatter_aggregate_tokenizer_missing_prompt_language(canary_tokenizer): + # Note the 'canary_tokenizer' arg which is an aggregate tokenizer fixture. + formatter = _DummyPromptFormatter(canary_tokenizer) + + with pytest.raises(AssertionError, match="Missing key 'prompt_language' in slot_values"): + formatter.encode_dialog([{"role": "user", "slots": {"text": "hi"}}]) + + +class _DummyPreamblePromptFormatter(PromptFormatter): + NAME = "_dummy_preamble_test_formatter" + TEMPLATE = { + "preamble": {"template": "TEST"}, + "user": {"template": "|text|", "slots": {"text": Modality.Text}}, + "assistant": {"template": "|text|", "slots": {"text": Modality.Text}}, + } + OUTPUT_ROLE = "assistant" + + +def test_prompt_formatter_preamble_inference(bpe_tokenizer): + formatter = _DummyPreamblePromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog([{"role": "user", "slots": {"text": "hi"}}]) + recovered = bpe_tokenizer.ids_to_text(ans["input_ids"]) + assert recovered == "TEST hi", recovered + + +def test_prompt_formatter_premble_training(bpe_tokenizer): + formatter = _DummyPreamblePromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog( + [ + {"role": "user", "slots": {"text": "hi"}}, + {"role": "assistant", "slots": {"text": "hello"}}, + ] + ) + recovered = bpe_tokenizer.ids_to_text(ans["input_ids"]) + assert recovered == "TEST hi hello" + + +def test_prompt_formatter_explicit_preamble(bpe_tokenizer): + formatter = _DummyPreamblePromptFormatter(bpe_tokenizer) + ans = formatter.encode_dialog([{"role": "preamble"}, {"role": "user", "slots": {"text": "hi"}}]) + recovered = bpe_tokenizer.ids_to_text(ans["input_ids"]) + assert recovered == "TEST hi" + + +def test_prompt_formatter_wrong_preamble_excpetions(bpe_tokenizer): + formatter = _DummyPreamblePromptFormatter(bpe_tokenizer) + with pytest.raises(AssertionError): + # Error: 2 preambles + formatter.encode_dialog( + [ + {"role": "preamble"}, + {"role": "preamble"}, + {"role": "user", "slots": {"text": "hi"}}, + ] + ) + with pytest.raises(AssertionError): + # Error: preamble not at the beginning + formatter.encode_dialog( + [ + {"role": "user", "slots": {"text": "hi"}}, + {"role": "preamble"}, + ] + ) + with pytest.raises(AssertionError): + # Error: preamble with slots + formatter.encode_dialog( + [ + {"role": "user", "slots": {"text": "hi"}}, + {"role": "preamble", "slots": {"abc": "abc"}}, + ] + )