Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt formatter API and canary transcribe tensor input support #9206

Merged
merged 22 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
bfbacdc
Apply CanaryPromptFormatter in dataset/inference
pzelasko May 15, 2024
967776b
Working inference with CanaryPromptFormatter
pzelasko May 15, 2024
04fdba9
Minimum working example of Canary.transcribe() with tensors
pzelasko May 15, 2024
1f902ff
training fix
pzelasko May 15, 2024
e86362e
Update to the new 'chat' based prompt formatting API
pzelasko May 21, 2024
41f96f1
Prompt formatters for popular models and partial unit test coverage
pzelasko May 21, 2024
06ff96d
Updated documentation
pzelasko May 21, 2024
71b9191
Improved test coverage + proper preamble support
pzelasko May 22, 2024
5555a4c
Fix usage of PromptFormatter for MT-AED class + fix tokenization/form…
pzelasko May 22, 2024
2350356
Move some canary hacks to canary prompt formatter, improve validation…
pzelasko May 22, 2024
30713b8
aed_model.transcribe(**slots) support, rename all slots to lowercase …
pzelasko May 23, 2024
9334a88
truly generic version
pzelasko May 23, 2024
2f7cd7a
making transcribe_speech.py work prompt slots + syntactic sugar
pzelasko May 23, 2024
3a533ae
update streaming_utils.py
pzelasko May 23, 2024
d6f75f0
Merge branch 'main' into prompt-formatter-and-canary-tensor-dataset
pzelasko May 23, 2024
61f92d8
fix
pzelasko May 23, 2024
9fe28cb
code review: partial
pzelasko May 24, 2024
3f60244
Accept multi-turn, single-turn, and legacy prompt format in transcrib…
pzelasko May 29, 2024
9e13c2e
Address code reviews
pzelasko May 31, 2024
3f9453b
Add support for SPE special tokens bos/eos in prompt templates and en…
pzelasko May 31, 2024
55ac422
Fix tests and add llama2 prompt formatter tests
pzelasko May 31, 2024
43ec9ad
Fix tests
pzelasko May 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 29 additions & 94 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from lhotse.dataset.collation import collate_vectors

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
from nemo.collections.common.prompts.canary import CanaryPromptFormatter, map_manifest_values_to_special_tokens
from nemo.collections.common.tokenizers import CanaryTokenizer, TokenizerSpec
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER


class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -132,7 +134,7 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -
assert isinstance(
tokenizer._tokenizer, CanaryTokenizer
), "To use 'canary' prompt format, you must use the CanaryTokenizer."
tokenizer = tokenizer._tokenizer
formatter = CanaryPromptFormatter(tokenizer._tokenizer)

tokens, prompts = [], []
for cut in cuts:
Expand All @@ -148,100 +150,33 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -
f"Please ensure that every utterance in the input manifests contains these keys."
)

# Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together.
texts = [sup.text for sup in cut.supervisions]
langs = [sup.language for sup in cut.supervisions]
taskname = cut.custom['taskname']
pnc = cut.custom['pnc']
source_lang = cut.custom['source_lang']
target_lang = cut.custom['target_lang']

tokens.append(canary_prompt(tokenizer, texts, langs, source_lang, target_lang, taskname, pnc))
if inference:
prompts.append(canary_prompt(tokenizer, None, None, source_lang, target_lang, taskname, pnc))
return tokens, prompts


def canary_prompt(
tokenizer: CanaryTokenizer,
text: str | list[str] | None,
language: str | list[str] | None,
source_language: str,
target_language: str,
taskname: str,
pnc: str,
) -> list[int]:
if isinstance(text, str):
text = [text]
if isinstance(language, str):
language = [language]

if text is not None:
try:
tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[])
except omegaconf.errors.KeyValidationError as e:
raise ProbablyIncorrectLanguageKeyError(
"We couldn't select the right tokenizer, which could be due to issues with reading "
"the language from the manifest. "
"If you're training, try setting lang_field='' to a different value (probably 'target_lang' or 'lang'). "
"If you're using model.transcribe() directly, please use override_config kwarg to set this. "
"If you're using transcribe_speech.py, use option gt_lang_attr_name='...' "
) from e
else:
tokens = None # create prompt for inference

# bos
prompted_tokens = [tokenizer.bos_id]

if tokens is not None and len(tokens) == 0:
# no speech token
prompted_tokens.append(tokenizer.nospeech_id)
else:
# first, validate the utterance
if source_language is None or target_language is None or taskname is None or pnc is None:
raise RuntimeError(
f"Missing keys provided to prompt: "
f"source_langauge={source_language},\n"
f"target_language={target_language},\n"
f"taskname={taskname},\n"
f"pnc={pnc}\n"
f"Please ensure that every utterance in the input manifests contains these keys."
)

# src_lang_id/no_speech
src_lang_id = tokenizer.spl_token_to_id(source_language)
prompted_tokens.append(src_lang_id)

# task
task = taskname
if task == 'asr' or task == "transcribe":
prompted_tokens.append(tokenizer.spl_token_to_id("transcribe"))
elif task == 's2t_translation' or task == 'ast' or task == "translate":
prompted_tokens.append(tokenizer.spl_token_to_id("translate"))
else:
raise ValueError(f"Unknown task: {task}")

# tgt_lang_id
tgt_lang_id = tokenizer.spl_token_to_id(target_language)
prompted_tokens.append(tgt_lang_id)

# PnC
pnc = f"{pnc}".lower().strip() # to account for bool or str
if pnc in {'yes', 'true'}:
prompted_tokens.append(tokenizer.spl_token_to_id("pnc"))
elif pnc in {'no', 'false'}:
prompted_tokens.append(tokenizer.spl_token_to_id("nopnc"))
else:
raise ValueError(f"Unknown value for key 'pnc': {pnc}")

# text (only in training)
if tokens is not None:
prompted_tokens.extend(tokens)
encoded = formatter.encode_dialog(
turns=[
dict(
role="user",
slots=map_manifest_values_to_special_tokens(
{
"|SOURCE_LANG|": cut.custom['source_lang'],
"|TARGET_LANG|": cut.custom['target_lang'],
"|TASKNAME|": cut.custom['taskname'],
"|PNC|": cut.custom['pnc'],
"|PROMPT_LANGUAGE|": CANARY_SPECIAL_TOKENIZER,
}
),
),
dict(
role="assistant",
slots={
"|TEXT|": ' '.join(s.text for s in cut.supervisions),
"|PROMPT_LANGUAGE|": cut.custom["target_lang"],
},
),
]
)
tokens.append(encoded["input_ids"])
prompts.append(encoded["context_ids"])

# eos (only in training)
if tokens is not None:
prompted_tokens.append(tokenizer.eos_id)
return prompted_tokens
return tokens, prompts
pzelasko marked this conversation as resolved.
Show resolved Hide resolved


class ProbablyIncorrectLanguageKeyError(RuntimeError):
Expand Down
42 changes: 35 additions & 7 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from nemo.collections.common.metrics import GlobalAverageLossMetric
from nemo.collections.common.parts import transformer_weights_init
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
from nemo.collections.common.prompts.canary import map_manifest_values_to_special_tokens
from nemo.collections.common.prompts.formatter import PromptFormatter
from nemo.core.classes.common import typecheck
from nemo.core.neural_types import (
AudioSignal,
Expand Down Expand Up @@ -156,7 +158,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.transf_encoder = EncDecMultiTaskModel.from_config_dict(transf_encoder_cfg_dict)

# Initialize weights
std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden ** 0.5
std_init_range = 1 / self.cfg.model_defaults.lm_enc_hidden**0.5
self.transf_encoder.apply(lambda module: transformer_weights_init(module, std_init_range))

transf_decoder_cfg_dict = cfg.transf_decoder
Expand All @@ -182,7 +184,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight

# Initialize weights
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5
self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range))
self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

Expand Down Expand Up @@ -347,7 +349,7 @@ def change_vocabulary(
self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight

# Initialize weights of token classifier
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden**0.5
self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

# Setup Decoding class
Expand Down Expand Up @@ -716,6 +718,8 @@ def _transcribe_on_begin(self, audio, trcfg: MultiTaskTranscriptionConfig):
# Switch model to evaluation mode
self.transf_decoder.freeze()

self.prompt_formatter = PromptFormatter.resolve(self.prompt_format)(self.tokenizer)
pzelasko marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(audio, list):
logging.debug(f"Found 'audio' to be a list of {len(audio)} items.")
logging.debug(f"Assuming each item in 'audio' is a path to audio file.")
Expand All @@ -736,9 +740,6 @@ def _transcribe_on_begin(self, audio, trcfg: MultiTaskTranscriptionConfig):
if hasattr(trcfg, '_internal') and hasattr(trcfg._internal, 'manifest_path'):
trcfg._internal.manifest_filepath = manifest_path

elif isinstance(audio, (np.ndarray, torch.Tensor)):
raise NotImplementedError("Transcribing from numpy or torch tensors is not supported yet.")

def _transcribe_input_manifest_processing(
self, audio_files: List[str], temp_dir: str, trcfg: MultiTaskTranscriptionConfig
) -> Dict[str, Any]:
Expand Down Expand Up @@ -790,7 +791,34 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig):
log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=batch[0], input_signal_length=batch[1]
)
decoder_input_ids = batch[-2].to(trcfg._internal.device)
if len(batch) >= 4:
# Backward compatibility: the decoder_input_ids were provided by the dataloader.
decoder_input_ids = batch[-2].to(trcfg._internal.device)
else:
# The dataloader provided only audio + audio_lens, so we
# are constructing the prompt dynamically using TranscribeConfig.
# TODO: have 'slots' dict/dataclass in trcfg instead of getattr()
# for now to get a POC working, I'm manually mapping the specific existing canary's format
turns = [
dict(
role="user",
slots=map_manifest_values_to_special_tokens(
{
"|TASKNAME|": trcfg.task,
"|SOURCE_LANG|": trcfg.source_lang,
"|TARGET_LANG|": trcfg.target_lang,
"|PNC|": trcfg.pnc,
"|PROMPT_LANGUAGE|": "spl_tokens",
}
),
)
]
decoder_input_ids = (
self.prompt_formatter.encode_dialog(turns=turns)["context_ids"]
.unsqueeze(0)
.repeat(batch[0].shape[0], 1)
.to(trcfg._internal.device)
)
output = dict(
log_probs=log_probs,
encoded_lengths=encoded_len,
Expand Down
73 changes: 54 additions & 19 deletions nemo/collections/asr/parts/utils/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from nemo.collections.asr.data.audio_to_text_lhotse_prompted import canary_prompt
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder
from nemo.collections.asr.parts.preprocessing.features import normalize_batch
from nemo.collections.asr.parts.utils.audio_utils import get_samples
from nemo.collections.common.prompts.canary import CanaryPromptFormatter, map_manifest_values_to_special_tokens
from nemo.core.classes import IterableDataset
from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType

Expand Down Expand Up @@ -444,7 +444,10 @@ def _convert_buffer_to_features(self):
device = self.asr_model.device
audio_signal = samples.unsqueeze_(0).to(device)
audio_signal_len = torch.Tensor([samples.shape[1]]).to(device)
features, features_len = self.raw_preprocessor(input_signal=audio_signal, length=audio_signal_len,)
features, features_len = self.raw_preprocessor(
input_signal=audio_signal,
length=audio_signal_len,
)
features = features.squeeze()
self._update_feature_buffer(features[:, -self.feature_chunk_len :])

Expand Down Expand Up @@ -479,7 +482,10 @@ def __init__(self, samples, frame_len, preprocessor, device, pad_to_frame_len=Tr
self._feature_frame_len = frame_len / timestep_duration
audio_signal = torch.from_numpy(self._samples).unsqueeze_(0).to(device)
audio_signal_len = torch.Tensor([self._samples.shape[0]]).to(device)
self._features, self._features_len = preprocessor(input_signal=audio_signal, length=audio_signal_len,)
self._features, self._features_len = preprocessor(
input_signal=audio_signal,
length=audio_signal_len,
)
self._features = self._features.squeeze()

def __iter__(self):
Expand Down Expand Up @@ -701,7 +707,12 @@ class for streaming frame-based ASR use reset() method to reset FrameASR's
"""

def __init__(
self, asr_model, frame_len=1.6, total_buffer=4.0, batch_size=4, pad_to_buffer_len=True,
self,
asr_model,
frame_len=1.6,
total_buffer=4.0,
batch_size=4,
pad_to_buffer_len=True,
):
'''
Args:
Expand Down Expand Up @@ -1183,7 +1194,9 @@ def _get_batch_preds(self):
del best_hyp, pred

def transcribe(
self, tokens_per_chunk: int, delay: int,
self,
tokens_per_chunk: int,
delay: int,
):
"""
Performs "middle token" alignment prediction using the buffered audio chunk.
Expand All @@ -1210,7 +1223,12 @@ def transcribe(
ids, toks = self._alignment_decoder(alignment, self.asr_model.tokenizer, self.blank_id)

if len(ids) > 0 and a_idx < signal_end_idx:
self.unmerged[idx] = inplace_buffer_merge(self.unmerged[idx], ids, delay, model=self.asr_model,)
self.unmerged[idx] = inplace_buffer_merge(
self.unmerged[idx],
ids,
delay,
model=self.asr_model,
)

output = []
for idx in range(self.batch_size):
Expand Down Expand Up @@ -1276,7 +1294,9 @@ def __init__(
self.alignment_basepath = alignment_basepath

def transcribe(
self, tokens_per_chunk: int, delay: int,
self,
tokens_per_chunk: int,
delay: int,
):
if self.lcs_delay < 0:
raise ValueError(
Expand All @@ -1302,7 +1322,10 @@ def transcribe(

if len(ids) > 0:
self.unmerged[idx] = inplace_buffer_merge(
self.unmerged[idx], ids, delay, model=self.asr_model,
self.unmerged[idx],
ids,
delay,
model=self.asr_model,
)

else:
Expand Down Expand Up @@ -1588,15 +1611,23 @@ def get_input_tokens(self, sample: dict):
f"We found sample that is missing the following keys: {missing_keys}"
f"Please ensure that every utterance in the input manifests contains these keys. Sample: {sample}"
)
tokens = canary_prompt(
tokenizer=self.asr_model.tokenizer,
text=None,
language=None,
source_language=sample['source_lang'],
target_language=sample['target_lang'],
taskname=sample['taskname'],
pnc=sample['pnc'],
)
formatter = CanaryPromptFormatter(self.asr_model.tokenizer)
pzelasko marked this conversation as resolved.
Show resolved Hide resolved
tokens = formatter.encode_dialog(
turns=[
{
"role": "user",
"slots": map_manifest_values_to_special_tokens(
{
"|SOURCE_LANG|": sample["source_lang"],
"|TARGET_LANG|": sample["target_lang"],
"|PNC|": sample["pnc"],
"|TASKNAME|": sample["taskname"],
"|PROMPT_LANGUAGE|": "spl_tokens",
}
),
}
]
)["context_ids"]
else:
raise ValueError(f"Unknown prompt format: {self.asr_model.prompt_format}")
return torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0) # [1, T]
Expand Down Expand Up @@ -1712,12 +1743,16 @@ def _get_batch_preds(self, keep_logits=False):
encoded, encoded_len = results
log_probs = self.asr_model.ctc_decoder(encoder_output=encoded)
transcribed_texts, _ = self.asr_model.ctc_decoding.ctc_decoder_predictions_tensor(
decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False,
decoder_outputs=log_probs,
decoder_lengths=encoded_len,
return_hypotheses=False,
)
else:
log_probs, encoded_len, predictions = results
transcribed_texts, _ = self.asr_model.decoding.ctc_decoder_predictions_tensor(
decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False,
decoder_outputs=log_probs,
decoder_lengths=encoded_len,
return_hypotheses=False,
)

self.all_preds.extend(transcribed_texts)
Expand Down
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a thought.. can we also a sample.py/simple.py template with the simplest possible template and add few comments about which routines need to be defined. (This is mainly coming from - if a user wants to create their own custom template; I know there are plenty of examples already.. )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this, yeah a canonical form of template to copy paste and directly modify

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

Empty file.
Loading
Loading