Merged
35 commits
e4e1315  initial working additions (Mar 30, 2023)
a8a28ec  clean and rename, add cond stripping initial prompt to decode (Mar 30, 2023)
dcaad98  cleanup, edit create_initial_prompt_ids, add tests (Mar 30, 2023)
c5d3ab5  repo consistency, flip order of conditional (Mar 31, 2023)
052d60d  fix error, move the processor fn to the tokenizer (Mar 31, 2023)
ce324b0  repo consistency, update test ids to corresponding tokenizer (Mar 31, 2023)
3097061  use convert_tokens_to_ids not get_vocab... (Mar 31, 2023)
5a122ed  use actual conditional in generate (Apr 3, 2023)
8c638b7  make sytle (Apr 3, 2023)
1ce5a1e  initial address comments (Apr 5, 2023)
4a45a86  initial working add new params to pipeline (Apr 10, 2023)
2b12d21  first draft of sequential generation for condition_on_previous_text (Apr 11, 2023)
f5f2ab6  add/update tests, make compatible with timestamps (Apr 13, 2023)
45992aa  make compatible with diff. input kwargs and max length (Apr 14, 2023)
5dcba16  add None check (Apr 14, 2023)
f57577b  add temperature check (Apr 14, 2023)
1f5e596  flip temp check operand (Apr 14, 2023)
2ce4035  refocusing to prev pr scope (Apr 25, 2023)
17d1046  remove the params too (Apr 26, 2023)
bb4d2f5  make style (Apr 26, 2023)
44a1d08  edits, move max length incorporating prompt to whisper (Apr 26, 2023)
9ab0c6c  address comments (May 3, 2023)
3962906  remove asr pipeline prompt decoding, fix indexing (May 3, 2023)
af5d8e8  address comments (more tests, validate prompt) (May 3, 2023)
0a92f36  un-comment out tests (from debug) (May 3, 2023)
ba4c652  remove old comment (May 3, 2023)
cae449e  address comments (May 10, 2023)
7184f5f  fix typo (May 10, 2023)
26cb7e7  remove timestamp token from test (May 10, 2023)
7c2a1d2  make style (May 10, 2023)
f0df1f1  cleanup (May 10, 2023)
0add3c7  copy method to fast tokenizer, set max_new_tokens for test (May 10, 2023)
f0a0364  prompt_ids type just pt (May 16, 2023)
1af5e8f  address Amy's comments (May 17, 2023)
caff8be  make style (May 17, 2023)
75 changes: 60 additions & 15 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -34,7 +34,12 @@
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_whisper import WhisperConfig
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE

@@ -1464,6 +1469,7 @@ def generate(
task=None,
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
Contributor:
I think prompt_ids should not be allowed to be a numpy array given its signature (see: https://github.com/huggingface/transformers/pull/22496/files#r1467369773)

**kwargs,
):
"""
@@ -1521,6 +1527,11 @@
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
prompt_ids (`torch.Tensor`, *optional*):
Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -1567,8 +1578,21 @@
if task is not None:
generation_config.task = task

forced_decoder_ids = []
if task is not None or language is not None:
forced_decoder_ids = None

# Legacy code for backward compatibility
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
else:
forced_decoder_ids = kwargs.get("forced_decoder_ids", None)

if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
Contributor Author:
This is done solely for handling the case where prompt_ids are passed in but the generation config's and the model config's forced decoder ids are both None. It's essentially just changing the order of operations so that we can cleanly check forced_decoder_ids is None and prompt_ids is not None, and then add the non-prompt forced decoder ids; none of the other functionality should change.

forced_decoder_ids = []
if hasattr(generation_config, "language"):
if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
@@ -1593,27 +1617,48 @@
raise ValueError(
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
)
else:
elif hasattr(generation_config, "task_to_id"):
Comment on lines -1596 to +1620
Contributor Author:
Thanks @amyeroberts, comments addressed

The only callout I have is that I updated this line to handle the case where an English-only model is being used and model.config.forced_decoder_ids and generation_config.forced_decoder_ids are both None, since the generation_config doesn't have task_to_id in that case.

forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

# Legacy code for backward compatibility
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
if forced_decoder_ids is not None:
generation_config.forced_decoder_ids = forced_decoder_ids

if prompt_ids is not None:
if kwargs.get("decoder_start_token_id") is not None:
raise ValueError(
"When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
)
prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids
# Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id})

# Update the max generation length to include the prompt
specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
default_max_length = generation_config.max_new_tokens or generation_config.max_length
non_prompt_max_length = specified_max_length or default_max_length
kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)

# Reformat the forced_decoder_ids to incorporate the prompt
non_prompt_forced_decoder_ids = (
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
)
Comment on lines +1646 to +1648
Contributor:
Nice! I think this now supports the different ways that forced_decoder_ids may be passed in?

  1. Through model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language=..., task=...)

  2. Through model.generate(input_features, forced_decoder_ids=forced_decoder_ids)

  3. Through model.generate(input_features, language=..., task=...)

It would be good if there are unit tests for these different methods.

connor-henderson (Contributor Author), May 2, 2023:
I don't believe model.generate allows passing in task or language directly as in 3. above, but I've now added tests for the other two

Contributor:
It does allow that (and I think it might even be the preferred method now) but for some reason the language needs to be the token, such as "<|ja|>" rather than "ja".

Contributor:
cc @gante

gante (Contributor), May 17, 2023:
@connor-henderson update on the language code: we now support passing the language token, the language code, or the language name. See this (very recent) PR :)

(not sure if this info has gotten to you, many conversations in parallel in this PR)

Contributor:
Note that the language code change was @connor-henderson's most recent PR! This forced_decoder_ids logic is in place so that the code is backward compatible with our previous way of handling the language/task, where we either set it in the config as config.forced_decoder_ids, or explicitly as forced_decoder_ids to the generate method (see #21965 (comment))

Contributor:
haha derp, I didn't look at the author 🙈 my bad!

forced_decoder_ids = [
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
*text_prompt_ids[-self.config.max_length // 2 - 1 :],
Contributor:
Is there a reason behind this slicing? Intuitively it makes sense to me, but I'm curious to know if there is a reference behind this choice :)

Contributor Author:
Sure, I'll leave a comment in the code too. This is done to match Whisper's implementation. I believe the reason they do the -1 is to make room for the first token to generate, and the reason they do // 2 is to halve it to share context space with a prefix if one is provided (which also gets halved). I don't believe there's prefix support yet in transformers, so technically the // 2 isn't necessary at this point, but I didn't want to confuse future work around that if it happens. There's a good clarification of prompt vs prefix here if it's of interest.

Hello @connor-henderson, as I am using the prompting feature I noticed a bug for long prompts. It might be caused by the slicing, where it should be text_prompt_ids = text_prompt_ids[-(self.config.max_length // 2 - 1) :], to correctly account for the first token <|startofprev|>.
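
A quick standalone check of the two slicing variants, as a sketch: the max_length value of 448 (Whisper's usual decoder context length) is an assumption for illustration.

# Illustrative comparison of the two prompt-truncation slices discussed above,
# assuming config.max_length == 448 (Whisper's usual decoder context length).
max_length = 448
text_prompt_ids = list(range(1000))  # a deliberately over-long prompt

current = text_prompt_ids[-max_length // 2 - 1 :]     # (-448 // 2) - 1 == -225, keeps 225 tokens
proposed = text_prompt_ids[-(max_length // 2 - 1) :]  # -(224 - 1) == -223, keeps 223 tokens
print(len(current), len(proposed))  # 225 223

The proposed form keeps two fewer tokens, which matches the point above about leaving room for the <|startofprev|> token within the half-context budget.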

Contributor:
Hey @Helene-Maxcici, feel free to open a new issue to track this bug, tagging myself (and optionally @connor-henderson). In particular, it would be super helpful to have a reproducible code snippet to emulate the behaviour locally. See the following page for details: https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#submitting-a-bug-related-issue-or-feature-request

generation_config.decoder_start_token_id,
*[token for _rank, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
generation_config.forced_decoder_ids = forced_decoder_ids

if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]

if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids

return super().generate(
inputs,
generation_config,
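Taken together, the generate() changes above enable prompting like the following. This is a minimal usage sketch, not code from the PR; the checkpoint and the dummy-dataset audio loading are assumptions for illustration.

# Minimal sketch of prompting Whisper generation with the new prompt_ids argument.
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt").input_features

# Bias the transcription towards a particular spelling of a proper noun.
prompt_ids = processor.get_prompt_ids("Leighton")

# prompt_ids cannot be combined with decoder_start_token_id, since generate() overwrites it.
output = model.generate(input_features, prompt_ids=prompt_ids, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))

With skip_special_tokens=True the decoded text omits the <|startofprev|> prompt segment, thanks to the _strip_prompt change in the tokenizer files below.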
4 changes: 4 additions & 0 deletions src/transformers/models/whisper/processing_whisper.py
@@ -16,6 +16,7 @@
Speech processor class for Whisper
"""


from ...processing_utils import ProcessorMixin


@@ -91,3 +92,6 @@ def decode(self, *args, **kwargs):
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)

def get_prompt_ids(self, text: str, return_tensors="np"):
return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
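
As a rough illustration of what this delegated call returns, a sketch; the exact ids are checkpoint-dependent, and the values shown follow the processor test later in this diff.

# Sketch: inspecting the prompt ids produced through the processor.
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
prompt_ids = processor.get_prompt_ids("Mr. Quilter")   # numpy array by default
print(prompt_ids.tolist())                             # e.g. [50360, 1770, 13, 2264, 346, 353]
print(processor.tokenizer.decode(prompt_ids))          # <|startofprev|> Mr. Quilter

# Text containing special tokens is rejected, e.g.
# processor.get_prompt_ids("<|startofprev|> test") raises a ValueError.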
30 changes: 30 additions & 0 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -606,6 +606,11 @@ def _decode(
) -> str:
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

if skip_special_tokens:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)

filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)

# To avoid mixing byte-level and unicode for byte-level BPT
@@ -714,6 +719,31 @@ def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time
time_precision=time_precision,
)

def get_prompt_ids(self, text: str, return_tensors="np"):
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)

# Check for special tokens
prompt_text_ids = batch_encoding["input_ids"][1:]
special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
if special_token_id is not None:
token = self.convert_ids_to_tokens(special_token_id)
raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")

batch_encoding.convert_to_tensors(tensor_type=return_tensors)
return batch_encoding["input_ids"]

@staticmethod
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
else:
return []

return token_ids
Comment on lines +737 to +745
Contributor:
nit: suggestion, feel free to ignore

I would rewrite this with early returns to make the logic a bit clearer here

Suggested change
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
else:
return []
return token_ids
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
if not has_prompt:
return token_ids
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
return []



def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
"""
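A small sketch of the decode-side effect, mirroring the tokenizer test at the end of this diff; the checkpoint is an assumption, and the exact prefix tokens depend on the tokenizer's configured language/task.

# Sketch: with this change, skip_special_tokens=True also drops the prompt segment
# (everything before <|startoftranscript|>) when decoding.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

prompt_ids = tokenizer.get_prompt_ids("Mr. Quilter").tolist()   # "<|startofprev|> Mr. Quilter"
transcript_ids = tokenizer(" Hello world.").input_ids           # "<|startoftranscript|>... Hello world.<|endoftext|>"
ids = prompt_ids + transcript_ids

print(tokenizer.decode(ids, skip_special_tokens=False))  # prompt text and special tokens kept
print(tokenizer.decode(ids, skip_special_tokens=True))   # " Hello world." only; the prompt is stripped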
32 changes: 32 additions & 0 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -312,6 +312,11 @@ def decode(
return text

def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
if kwargs["skip_special_tokens"]:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id)

text = super()._decode(*args, **kwargs)

if normalize:
@@ -485,3 +490,30 @@ def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time
return_language=return_language,
time_precision=time_precision,
)

Contributor:
Are we missing the get_prompt_ids method from the fast tokenizer?

Contributor Author:
Ah yes, thanks

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids
def get_prompt_ids(self, text: str, return_tensors="np"):
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
Contributor:
Quick comment here. By default, WhisperTokenizerFast sets add_prefix_space to False. After being initialized, the value cannot be changed (see related #17391).

from transformers import WhisperTokenizerFast           
                                                                      
tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny")
                                                                      
tokenizer("<|startofprev|>", "test", add_special_tokens=False).input_ids
> [50361, 31636]

tokenizer("<|startofprev|>", "test", add_special_tokens=False, add_prefix_space=True).input_ids   
> TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'add_prefix_space'

Is it necessary to have add_prefix_space be False by default? Pinging @hollance and @sanchit-gandhi too.

Contributor:
Since Whisper always infers an extra space after the <|startoftranscript|> / <|startofprev|>, I see why it could make sense to change add_prefix_space to True by default. Probably we should have set it to True when we added the model in the first place, but right now to swap it would be a breaking change since we would silently change the behaviour of the tokenizer

I'm not sure there's a clean way we can do this directly? Probably easiest is to instantiate the tokenizer with add_prefix_space=True where required:

from transformers import WhisperTokenizerFast

tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny", add_prefix_space=True)

Contributor:
For training with initial prompts, right now we are using two tokenizers, which is not ideal.


# Check for special tokens
prompt_text_ids = batch_encoding["input_ids"][1:]
special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
if special_token_id is not None:
token = self.convert_ids_to_tokens(special_token_id)
raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")

batch_encoding.convert_to_tensors(tensor_type=return_tensors)
return batch_encoding["input_ids"]

@staticmethod
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
else:
return []

return token_ids
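
Putting the snippets from the add_prefix_space discussion together, a minimal sketch of the suggested load-time workaround; the behaviour noted in the comments is an assumption based on that thread, not something tested in this PR.

# Sketch: for WhisperTokenizerFast, add_prefix_space cannot be toggled per call,
# so it has to be set when the tokenizer is instantiated.
from transformers import WhisperTokenizerFast

tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny", add_prefix_space=True)
print(tokenizer("<|startofprev|>", "test", add_special_tokens=False).input_ids)
# With the default add_prefix_space=False the thread above reports [50361, 31636];
# with the prefix space enabled, "test" is encoded with a leading space, so the second id may differ.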
96 changes: 96 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -1013,6 +1013,48 @@ def test_mask_time_prob(self):
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))

def test_generate_with_prompt_ids_and_task_and_language(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
input_features = input_dict["input_features"]
prompt_ids = np.arange(5)
language = "<|de|>"
task = "translate"
lang_id = 6
task_id = 7
model.generation_config.__setattr__("lang_to_id", {language: lang_id})
model.generation_config.__setattr__("task_to_id", {task: task_id})

output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)

expected_output_start = [
*prompt_ids.tolist(),
model.generation_config.decoder_start_token_id,
lang_id,
task_id,
]
for row in output.tolist():
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)

def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
input_features = input_dict["input_features"]
prompt_ids = np.asarray(range(5))
forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

output = model.generate(
input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
Contributor:
Why do we allow passing prompt_ids as a numpy array here?

)

expected_output_start = [
*prompt_ids.tolist(),
model.generation_config.decoder_start_token_id,
*[token for _rank, token in forced_decoder_ids],
]
for row in output.tolist():
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)


@require_torch
@require_torchaudio
@@ -1429,6 +1471,60 @@ def test_tiny_specaugment_librispeech(self):
# fmt: on
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))

@slow
def test_generate_with_prompt_ids(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(4)[-1:]
input_features = processor(input_speech, return_tensors="pt").input_features

output_without_prompt = model.generate(input_features)
prompt_ids = processor.get_prompt_ids("Leighton")
output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids)

expected_without_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
expected_with_prompt = "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
self.assertEqual(processor.decode(output_without_prompt[0]), expected_without_prompt)
self.assertEqual(processor.decode(output_with_prompt[0]), expected_with_prompt)

@slow
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
Contributor Author:
I feel like this is most readable / simplest as one test with comments clarifying the cases, lmk if you want them split into separate unit tests

Contributor:
I agree it's very readable but there's a potential issue: the model will keep state around, i.e. it will change the model.generation_config object with the new forced_decoder_ids and this may affect the next test. So I think it's better to instantiate a new model before each test.

Maybe it's also a good idea to test what happens after you do the following, just to make sure the code can handle both of these things being None:

model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None

Contributor Author:
Added a case for the above, which involved a change in generate I'll call out below. I was aiming to order the tests to prevent conflicting state issues, but you're right that they're more brittle that way, so I've split them into separate unit tests.

Contributor:
Sorry, I wasn't able to fully understand the last comment - for testing the case when:

model.config.forced_decoder_ids = None 
model.generation_config.forced_decoder_ids = None

is this tested?

Contributor Author:
Sorry, moving parts: we had a test explicitly for this when there were 5 test cases. Then we trimmed them; per this #22496 (comment) I changed the test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids test to use whisper-base.en and return_timestamps=True. I just tested that combination, though, and realized it didn't actually set those attributes to None, so I updated the test to explicitly set those two to None.

tl;dr it was tested, then wasn't, now is again

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(1)
input_features = processor(input_speech, return_tensors="pt").input_features
task = "translate"
language = "de"
expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
prompt = "test prompt"
prompt_ids = processor.get_prompt_ids(prompt)

output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids)
text = processor.decode(output[0])

self.assertTrue(prompt in text)
self.assertTrue(all([token in text for token in expected_tokens]))

@slow
def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
input_speech = self._load_datasamples(1)
input_features = processor(input_speech, return_tensors="pt").input_features
prompt = "test prompt"
prompt_ids = processor.get_prompt_ids(prompt)

model.generation_config.forced_decoder_ids = None
model.config.forced_decoder_ids = None

output = model.generate(input_features, prompt_ids=prompt_ids, return_timestamps=True)
text = processor.decode(output[0])

self.assertTrue(prompt in text)


def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:
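To make explicit what the two unit tests above assert, here is a standalone sketch (not the library code itself) of how generate() rebuilds forced_decoder_ids when prompt_ids are supplied; the token ids are the dummy values used in those tests, and the <|startoftranscript|> id is a placeholder.

# Standalone sketch of the forced_decoder_ids re-indexing done in generate() when
# prompt_ids are passed; ids are dummy values mirroring the unit tests above.
prompt_ids = [0, 1, 2, 3, 4]                      # <|startofprev|> plus prompt text ids
original_decoder_start = 50257                    # <|startoftranscript|> (placeholder value)
non_prompt_forced_decoder_ids = [(1, 6), (2, 7)]  # e.g. language and task ids from the test

# The first prompt token becomes the new decoder_start_token_id; the rest is forced.
new_decoder_start, *text_prompt_ids = prompt_ids
forced = [
    *text_prompt_ids,  # truncated to roughly half the max length in the real code
    original_decoder_start,
    *[token for _rank, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced)]
print(new_decoder_start, forced_decoder_ids)
# Generation therefore starts with: 0 1 2 3 4 50257 6 7 ...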
31 changes: 31 additions & 0 deletions tests/models/whisper/test_processor_whisper.py
@@ -16,6 +16,8 @@
import tempfile
import unittest

import pytest

from transformers import WhisperTokenizer, is_speech_available
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio

@@ -146,3 +148,32 @@ def test_get_decoder_prompt_ids(self):

expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)

def test_get_prompt_ids(self):
Contributor:
Could you also add some tests for edge cases?

For example: processor.get_prompt_ids("") or processor.get_prompt_ids("<|startofprev|> Mr. <|startofprev|> Quilter")

connor-henderson (Contributor Author), May 3, 2023:
The second will definitely confuse the model and decoding if passed to the current get_prompt_ids as is. Would you prefer we strip the prompt start token or raise an error that it was included? I'll push up a change that strips it for now; lmk which you prefer and if you'd want to log a warning as well.

Contributor:
I don't really know what would be the best approach here, was just trying to think of things that might go wrong. ;-)

Perhaps raising an error on unexpected input is the best choice, but only if it doesn't add a lot of complexity to the code.

Contributor Author:
Looks like they have their tiktoken package handle it, and it raises an error if any special token is included, so I will look to do the same.

processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
prompt_ids = processor.get_prompt_ids("Mr. Quilter")
decoded_prompt = processor.tokenizer.decode(prompt_ids)

self.assertListEqual(prompt_ids.tolist(), [50360, 1770, 13, 2264, 346, 353])
self.assertEqual(decoded_prompt, "<|startofprev|> Mr. Quilter")

def test_empty_get_prompt_ids(self):
Contributor:
Nice :)

processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
prompt_ids = processor.get_prompt_ids("")
decoded_prompt = processor.tokenizer.decode(prompt_ids)

self.assertListEqual(prompt_ids.tolist(), [50360, 220])
self.assertEqual(decoded_prompt, "<|startofprev|> ")

def test_get_prompt_ids_with_special_tokens(self):
processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())

def _test_prompt_error_raised_helper(prompt, special_token):
with pytest.raises(ValueError) as excinfo:
processor.get_prompt_ids(prompt)
expected = f"Encountered text in the prompt corresponding to disallowed special token: {special_token}."
self.assertEqual(expected, str(excinfo.value))

_test_prompt_error_raised_helper("<|startofprev|> test", "<|startofprev|>")
_test_prompt_error_raised_helper("test <|notimestamps|>", "<|notimestamps|>")
_test_prompt_error_raised_helper("test <|zh|> test <|transcribe|>", "<|zh|>")
19 changes: 19 additions & 0 deletions tests/models/whisper/test_tokenization_whisper.py
@@ -194,6 +194,25 @@ def test_find_longest_common_subsequence(self):
merge = _find_longest_common_sequence([seq1, seq2, seq3])
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])

def test_skip_special_tokens_skips_prompt_ids(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
# fmt: off
encoded_input = [
50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
]
# fmt: on
expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens)
self.assertEqual(rust_tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
self.assertEqual(
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
)


class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"