feat: Whisper prompting #22496
src/transformers/models/whisper/modeling_whisper.py
@@ -34,7 +34,12 @@
     SequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_whisper import WhisperConfig
 from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
@@ -1464,6 +1469,7 @@ def generate(
         task=None,
         language=None,
         is_multilingual=None,
+        prompt_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         """
@@ -1521,6 +1527,11 @@ def generate(
                 find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
             is_multilingual (`bool`, *optional*):
                 Whether or not the model is multilingual.
+            prompt_ids (`torch.Tensor`, *optional*):
+                Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
+                provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
+                transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
+                correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
             kwargs:
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
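In practice, the new argument is used together with `WhisperProcessor.get_prompt_ids`. A minimal usage sketch, mirroring the slow integration test added further down (`audio_array` is assumed to be a 16 kHz mono waveform loaded elsewhere):

```python
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# audio_array: 16 kHz mono waveform (assumption for this sketch)
input_features = processor(audio_array, sampling_rate=16_000, return_tensors="pt").input_features

# Bias the model towards the spelling "Leighton" rather than "Layton"
prompt_ids = processor.get_prompt_ids("Leighton")
output = model.generate(input_features, prompt_ids=prompt_ids)

# The prompt is prepended to the generated sequence; skip_special_tokens=True strips it again
print(processor.decode(output[0], skip_special_tokens=True))
```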
@@ -1567,8 +1578,21 @@ def generate(
         if task is not None:
             generation_config.task = task
 
-        forced_decoder_ids = []
-        if task is not None or language is not None:
+        forced_decoder_ids = None
+
+        # Legacy code for backward compatibility
+        if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
+            forced_decoder_ids = self.config.forced_decoder_ids
+        elif (
+            hasattr(self.generation_config, "forced_decoder_ids")
+            and self.generation_config.forced_decoder_ids is not None
+        ):
+            forced_decoder_ids = self.generation_config.forced_decoder_ids
+        else:
+            forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
+
+        if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
+            forced_decoder_ids = []
             if hasattr(generation_config, "language"):
                 if generation_config.language in generation_config.lang_to_id.keys():
                     language_token = generation_config.language
@@ -1593,27 +1617,48 @@
                     raise ValueError(
                         f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
                     )
-            else:
+            elif hasattr(generation_config, "task_to_id"):
                 forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe

[Review comment on lines -1596 to +1620 — PR author (Contributor): "Thanks @amyeroberts, comments addressed. Only callout I have is that I updated this line to handle the case where an english-only model is being used and the …"]

             if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
                 idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                 forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
 
-        # Legacy code for backward compatibility
-        elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
-            forced_decoder_ids = self.config.forced_decoder_ids
-        elif (
-            hasattr(self.generation_config, "forced_decoder_ids")
-            and self.generation_config.forced_decoder_ids is not None
-        ):
-            forced_decoder_ids = self.generation_config.forced_decoder_ids
+        if forced_decoder_ids is not None:
+            generation_config.forced_decoder_ids = forced_decoder_ids
 
+        if prompt_ids is not None:
+            if kwargs.get("decoder_start_token_id") is not None:
+                raise ValueError(
+                    "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
+                )
+            prompt_ids = prompt_ids.tolist()
+            decoder_start_token_id, *text_prompt_ids = prompt_ids
+            # Set the decoder_start_token_id to <|startofprev|>
+            kwargs.update({"decoder_start_token_id": decoder_start_token_id})
+
+            # Update the max generation length to include the prompt
+            specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
+            default_max_length = generation_config.max_new_tokens or generation_config.max_length
+            non_prompt_max_length = specified_max_length or default_max_length
+            kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
+
+            # Reformat the forced_decoder_ids to incorporate the prompt
+            non_prompt_forced_decoder_ids = (
+                kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
+            )
+            forced_decoder_ids = [
+                # Slicing the text prompt ids in a manner consistent with the OpenAI implementation
+                # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
+                *text_prompt_ids[-self.config.max_length // 2 - 1 :],
+                generation_config.decoder_start_token_id,
+                *[token for _rank, token in non_prompt_forced_decoder_ids],
+            ]
+            forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
+            generation_config.forced_decoder_ids = forced_decoder_ids
 
         if generation_config.return_timestamps:
             logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
 
-        if len(forced_decoder_ids) > 0:
-            generation_config.forced_decoder_ids = forced_decoder_ids
-
         return super().generate(
             inputs,
             generation_config,
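To make the prompt handling concrete, here is a small standalone sketch (not the library code itself) of how the forced token sequence is laid out once `prompt_ids` are supplied; the IDs are the toy values used by the unit tests further down, where 0 plays the role of `<|startofprev|>`, 100 stands in for `<|startoftranscript|>`, and `(1, 6), (2, 7), (3, 8)` for the language/task/no-timestamps entries that would otherwise be forced:

```python
prompt_ids = [0, 1, 2, 3, 4]                            # output of get_prompt_ids: <|startofprev|> + prompt text tokens
decoder_start_token_id, *text_prompt_ids = prompt_ids   # 0 (<|startofprev|>) becomes the decoder start token
non_prompt_forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]
startoftranscript_id = 100                              # placeholder for generation_config.decoder_start_token_id

forced_decoder_ids = [
    *text_prompt_ids,             # the prompt text is forced first, right after <|startofprev|>
    startoftranscript_id,         # then the usual transcript start
    *[token for _rank, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
print(forced_decoder_ids)
# [(1, 1), (2, 2), (3, 3), (4, 4), (5, 100), (6, 6), (7, 7), (8, 8)]
```

Generation therefore starts from `<|startofprev|>`, is forced through the prompt text and `<|startoftranscript|>`, and only then continues with the usual language/task/timestamp tokens; the `max_length // 2 - 1` slice above only truncates prompts longer than half the decoder context.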
src/transformers/models/whisper/tokenization_whisper.py
@@ -606,6 +606,11 @@ def _decode(
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
 
+        if skip_special_tokens:
+            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
+            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
+            token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
+
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
 
         # To avoid mixing byte-level and unicode for byte-level BPT
@@ -714,6 +719,31 @@ def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
             time_precision=time_precision,
         )
 
+    def get_prompt_ids(self, text: str, return_tensors="np"):
+        """Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
+        batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
+
+        # Check for special tokens
+        prompt_text_ids = batch_encoding["input_ids"][1:]
+        special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
+        if special_token_id is not None:
+            token = self.convert_ids_to_tokens(special_token_id)
+            raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
+
+        batch_encoding.convert_to_tensors(tensor_type=return_tensors)
+        return batch_encoding["input_ids"]
+
+    @staticmethod
+    def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
+        has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
+        if has_prompt:
+            if decoder_start_token_id in token_ids:
+                return token_ids[token_ids.index(decoder_start_token_id) :]
+            else:
+                return []
+
+        return token_ids
+
 
 def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
     """

[Review comment on lines +737 to +745 — Contributor: "nit: suggestion, feel free to ignore. I would write for early returns to make the logic a bit clearer here."]
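A behaviour-equivalent early-return version of `_strip_prompt` might look like the sketch below (meant to slot into the tokenizer class above, hence the `List` annotation; not necessarily the reviewer's exact suggestion):

```python
@staticmethod
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
    if not isinstance(token_ids, list) or not token_ids:
        # Non-list or empty input: nothing to strip
        return token_ids
    if token_ids[0] != prompt_token_id:
        # No <|startofprev|> prefix, so there is no prompt segment
        return token_ids
    if decoder_start_token_id in token_ids:
        # Keep everything from <|startoftranscript|> onwards
        return token_ids[token_ids.index(decoder_start_token_id) :]
    # The sequence contains only the prompt
    return []
```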
src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -312,6 +312,11 @@ def decode(
         return text
 
     def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
+        if kwargs["skip_special_tokens"]:
+            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
+            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
+            kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id)
+
         text = super()._decode(*args, **kwargs)
 
         if normalize:
@@ -485,3 +490,30 @@ def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
             return_language=return_language,
             time_precision=time_precision,
         )
+
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids
+    def get_prompt_ids(self, text: str, return_tensors="np"):
+        """Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
+        batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
+
+        # Check for special tokens
+        prompt_text_ids = batch_encoding["input_ids"][1:]
+        special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
+        if special_token_id is not None:
+            token = self.convert_ids_to_tokens(special_token_id)
+            raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
+
+        batch_encoding.convert_to_tensors(tensor_type=return_tensors)
+        return batch_encoding["input_ids"]
+
+    @staticmethod
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
+    def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
+        has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
+        if has_prompt:
+            if decoder_start_token_id in token_ids:
+                return token_ids[token_ids.index(decoder_start_token_id) :]
+            else:
+                return []
+
+        return token_ids
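End to end, the tokenizer additions behave as sketched below (using `openai/whisper-tiny`; the expected prompt string mirrors the processor tests at the end of this diff):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

prompt_ids = tokenizer.get_prompt_ids("Mr. Quilter")  # NumPy array by default, starting with the <|startofprev|> id
print(tokenizer.decode(prompt_ids))                   # expected: "<|startofprev|> Mr. Quilter"

# Simulate generated output: prompt ids, then <|startoftranscript|>, then some text tokens
sot_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
text_ids = tokenizer("Hello", add_special_tokens=False).input_ids
generated = [*prompt_ids.tolist(), sot_id, *text_ids]

# With skip_special_tokens=True, _strip_prompt drops everything before <|startoftranscript|>,
# so the prompt text no longer shows up in the decoded string
print(tokenizer.decode(generated, skip_special_tokens=True))
```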
tests/models/whisper/test_modeling_whisper.py
@@ -1013,6 +1013,48 @@ def test_mask_time_prob(self):
         encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
         self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))
 
+    def test_generate_with_prompt_ids_and_task_and_language(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
+        input_features = input_dict["input_features"]
+        prompt_ids = np.arange(5)
+        language = "<|de|>"
+        task = "translate"
+        lang_id = 6
+        task_id = 7
+        model.generation_config.__setattr__("lang_to_id", {language: lang_id})
+        model.generation_config.__setattr__("task_to_id", {task: task_id})
+
+        output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)
+
+        expected_output_start = [
+            *prompt_ids.tolist(),
+            model.generation_config.decoder_start_token_id,
+            lang_id,
+            task_id,
+        ]
+        for row in output.tolist():
+            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
+
+    def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
+        input_features = input_dict["input_features"]
+        prompt_ids = np.asarray(range(5))
+        forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]
+
+        output = model.generate(
+            input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
+        )

[Review comment on the generate call above — Contributor: "Why do we allow passing …"]

+
+        expected_output_start = [
+            *prompt_ids.tolist(),
+            model.generation_config.decoder_start_token_id,
+            *[token for _rank, token in forced_decoder_ids],
+        ]
+        for row in output.tolist():
+            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
+
 
 @require_torch
 @require_torchaudio
@@ -1429,6 +1471,60 @@ def test_tiny_specaugment_librispeech(self):
         # fmt: on
         self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
 
+    @slow
+    def test_generate_with_prompt_ids(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+        input_speech = self._load_datasamples(4)[-1:]
+        input_features = processor(input_speech, return_tensors="pt").input_features
+
+        output_without_prompt = model.generate(input_features)
+        prompt_ids = processor.get_prompt_ids("Leighton")
+        output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids)
+
+        expected_without_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
+        expected_with_prompt = "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
+        self.assertEqual(processor.decode(output_without_prompt[0]), expected_without_prompt)
+        self.assertEqual(processor.decode(output_with_prompt[0]), expected_with_prompt)
+
+    @slow
+    def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+        input_speech = self._load_datasamples(1)
+        input_features = processor(input_speech, return_tensors="pt").input_features
+        task = "translate"
+        language = "de"
+        expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
+        prompt = "test prompt"
+        prompt_ids = processor.get_prompt_ids(prompt)
+
+        output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids)
+        text = processor.decode(output[0])
+
+        self.assertTrue(prompt in text)
+        self.assertTrue(all([token in text for token in expected_tokens]))
+
+    @slow
+    def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        model.to(torch_device)
+        input_speech = self._load_datasamples(1)
+        input_features = processor(input_speech, return_tensors="pt").input_features
+        prompt = "test prompt"
+        prompt_ids = processor.get_prompt_ids(prompt)
+
+        model.generation_config.forced_decoder_ids = None
+        model.config.forced_decoder_ids = None
+
+        output = model.generate(input_features, prompt_ids=prompt_ids, return_timestamps=True)
+        text = processor.decode(output[0])
+
+        self.assertTrue(prompt in text)
+
 
 def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
     if head_mask is None:
tests/models/whisper/test_processor_whisper.py
@@ -16,6 +16,8 @@
 import tempfile
 import unittest
 
+import pytest
+
 from transformers import WhisperTokenizer, is_speech_available
 from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
@@ -146,3 +148,32 @@ def test_get_decoder_prompt_ids(self):
 
         expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
         self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
+
+    def test_get_prompt_ids(self):
+        processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+        prompt_ids = processor.get_prompt_ids("Mr. Quilter")
+        decoded_prompt = processor.tokenizer.decode(prompt_ids)
+
+        self.assertListEqual(prompt_ids.tolist(), [50360, 1770, 13, 2264, 346, 353])
+        self.assertEqual(decoded_prompt, "<|startofprev|> Mr. Quilter")
+
+    def test_empty_get_prompt_ids(self):
+        processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+        prompt_ids = processor.get_prompt_ids("")
+        decoded_prompt = processor.tokenizer.decode(prompt_ids)
+
+        self.assertListEqual(prompt_ids.tolist(), [50360, 220])
+        self.assertEqual(decoded_prompt, "<|startofprev|> ")

[Review comment on test_empty_get_prompt_ids — Contributor: "Nice :)"]

+
+    def test_get_prompt_ids_with_special_tokens(self):
+        processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+
+        def _test_prompt_error_raised_helper(prompt, special_token):
+            with pytest.raises(ValueError) as excinfo:
+                processor.get_prompt_ids(prompt)
+            expected = f"Encountered text in the prompt corresponding to disallowed special token: {special_token}."
+            self.assertEqual(expected, str(excinfo.value))
+
+        _test_prompt_error_raised_helper("<|startofprev|> test", "<|startofprev|>")
+        _test_prompt_error_raised_helper("test <|notimestamps|>", "<|notimestamps|>")
+        _test_prompt_error_raised_helper("test <|zh|> test <|transcribe|>", "<|zh|>")
[Review comment: "I think `prompt_ids` should not be allowed to be a numpy array given its signature (see https://github.com/huggingface/transformers/pull/22496/files#r1467369773)."]
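For callers who want to honour the documented `torch.Tensor` type today, the tensor can be requested or built explicitly; a hedged sketch reusing the `processor`, `model`, and `input_features` objects from the usage example near the top of this diff:

```python
import torch

# Convert the default NumPy output explicitly...
prompt_ids = torch.tensor(processor.get_prompt_ids("Mr. Quilter"))
# ...or ask for a framework tensor directly (get_prompt_ids exposes return_tensors, "np" by default)
# prompt_ids = processor.get_prompt_ids("Mr. Quilter", return_tensors="pt")

output = model.generate(input_features, prompt_ids=prompt_ids)
```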