Merged
45 changes: 22 additions & 23 deletions src/transformers/generation/logits_process.py
@@ -917,35 +917,34 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
probs to `inf` so that they are sampled at their corresponding index.

Args:
begin_index (`int`, *optional*, defaults to 5):
This indicates to the processor where the first tokens are generated. This is used to differentiate between
the `prompt` tokens and the `generated` tokens. When generating with `WhisperForConditionalGeneration` the
`prompt` tokens are the first 4 tokens.
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting
timestamps that are too far in the future.
generate_config (`GenerateConfig`):
The generate config used to generate the output. The following parameters are required:
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
"""

def __init__(
self,
begin_index=5,
eos_token_id=50257,
no_timestamps_token_id=50363,
max_initial_timestamp=1,
):
self.eos_token_id = eos_token_id
self.no_timestamps_token_id = no_timestamps_token_id
self.timestamp_begin = no_timestamps_token_id + 1
self.begin_index = begin_index
self.max_initial_timestamp_index = max_initial_timestamp
def __init__(self, generate_config): # support for the kwargs
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1

self.begin_index = len(generate_config.forced_decoder_ids) + 1
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
self.begin_index -= 1
if generate_config.is_multilingual:
self.begin_index += 1
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index

def __call__(self, input_ids, scores):
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
if input_ids.shape[1] == self.begin_index:
scores[:, self.timestamp_begin] = 0

# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):
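For reference, a minimal sketch (not part of this diff) of how the new constructor derives its state from a generation config. The `SimpleNamespace` stand-in and the token ids below are illustrative assumptions; in practice the argument is the model's `GenerationConfig`.

```python
from types import SimpleNamespace

from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

# Hypothetical stand-in for a Whisper generation config; ids are illustrative.
cfg = SimpleNamespace(
    eos_token_id=50257,
    no_timestamps_token_id=50363,
    max_initial_timestamp_index=1,
    is_multilingual=True,
    forced_decoder_ids=[(1, 50259), (2, 50359)],  # e.g. language and task tokens
)

processor = WhisperTimeStampLogitsProcessor(cfg)

# begin_index = len(forced_decoder_ids) + 1, plus one more for multilingual models
assert processor.begin_index == 4
assert processor.timestamp_begin == 50364  # first timestamp token, <|0.00|>
```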
145 changes: 145 additions & 0 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -25,6 +25,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -1231,6 +1232,150 @@ def forward(
encoder_attentions=outputs.encoder_attentions,
)

def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config=None,
logits_processor=None,
stopping_criteria=None,
prefix_allowed_tokens_fn=None,
synced_gpus=False,
return_timestamps=None,
task=None,
language=None,
is_multilingual=None,
**kwargs
):
"""

Generates sequences of token ids for models with a language modeling head.

<Tip warning={true}>

Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).

</Tip>

Parameters:
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config, an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criterion is passed that is already created with the arguments or a
generation config, an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
task (`str`, *optional*):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly.
language (`str`, *optional*):
Language token to use for generation, should be in the form `<|en|>`. You can find all the possible
language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:

- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]

If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:

- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
if generation_config is None:
generation_config = self.generation_config

if return_timestamps is not None:
generation_config.return_timestamps = return_timestamps

if task is not None:
generation_config.task = task

if is_multilingual is not None:
generation_config.is_multilingual = is_multilingual

if language is not None:
generation_config.language = language

forced_decoder_ids = []

if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
Reviewer note (Contributor): This is where we first introduced the generation config. Unless the task and language were passed as inputs, we'd default to speech transcription with language detection.

if hasattr(generation_config, "language"):
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
else:
forced_decoder_ids.append((1, None))

if hasattr(generation_config, "task"):
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))

if (
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
) or return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
else:
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids

return super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
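As a usage sketch (not part of this diff) of the new `generate()` arguments, something along the following lines should exercise the timestamp path. The checkpoint name and the dummy waveform are assumptions, and the checkpoint's generation config is assumed to expose `lang_to_id` and `task_to_id`.

```python
import numpy as np

from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Dummy 5-second waveform at 16 kHz, just to get the right input shapes.
audio = np.zeros(16000 * 5, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# return_timestamps=True enables the WhisperTimeStampLogitsProcessor internally.
generated_ids = model.generate(
    inputs.input_features,
    return_timestamps=True,
    task="transcribe",
    language="<|en|>",
    is_multilingual=True,
)

# decode_with_timestamps keeps the <|x.xx|> tokens in the decoded string.
print(processor.tokenizer.decode(generated_ids[0], decode_with_timestamps=True))
```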
26 changes: 25 additions & 1 deletion src/transformers/models/whisper/tokenization_whisper.py
@@ -493,6 +493,23 @@ def _normalize(self, text):
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
return normalizer(text)

def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
"""
timestamp_begin = self.all_special_ids[-1] + 1
outputs = [[]]
for token in token_ids:
if token >= timestamp_begin:
timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
return "".join(outputs)

def _compute_offsets(self, token_ids, time_precision=0.02):
"""
Compute offsets for a given tokenized input
@@ -544,6 +561,7 @@ def decode(
clean_up_tokenization_spaces: bool = True,
output_offsets: bool = False,
time_precision=0.02,
decode_with_timestamps: bool = False,
**kwargs
) -> str:
"""
@@ -561,7 +579,11 @@
Whether or not to clean up the tokenization spaces.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.

output_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text.
Returns:
`str`: The decoded sentence.
"""
Expand All @@ -571,6 +593,8 @@ def decode(
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
if decode_with_timestamps:
text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
# retrieve offsets
if output_offsets:
offsets = None
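A small sketch (not part of this diff) of what `decode_with_timestamps=True` produces. The checkpoint name and the hand-built token ids are illustrative assumptions.

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Timestamp tokens live just above the special-token id range.
timestamp_begin = tokenizer.all_special_ids[-1] + 1

text_ids = tokenizer(" Hello world.", add_special_tokens=False).input_ids
token_ids = [timestamp_begin] + text_ids + [timestamp_begin + 75]  # <|0.00|> ... <|1.50|>

print(tokenizer.decode(token_ids, decode_with_timestamps=True))
# expected to look like: "<|0.00|> Hello world.<|1.50|>"
```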
28 changes: 7 additions & 21 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -31,8 +31,6 @@
logger = logging.get_logger(__name__)

if is_torch_available():
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


@@ -413,13 +411,6 @@ def _sanitize_parameters(
if return_timestamps is not None:
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if self.model.config.model_type == "whisper":
# Whisper is highly specific, if we want timestamps, we need to
# force whisper to output timestamp tokens, which means we need
# to set this variable to prevent `no_timestamp_token` to be
# used in the decoder.
if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}):
forward_params["generate_kwargs"]["forced_decoder_ids"] = None

return preprocess_params, forward_params, postprocess_params

@@ -529,10 +520,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}

if return_timestamps and self.type == "seq2seq_whisper":
generate_kwargs["return_timestamps"] = return_timestamps
is_last = model_inputs.pop("is_last")

if self.type == "seq2seq":
if self.type in {"seq2seq", "seq2seq_whisper"}:
encoder = self.model.get_encoder()
# Consume values so we can let extra information flow freely through
# the pipeline (important for `partial` in microphone)
Expand All @@ -557,16 +549,10 @@ def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
**generate_kwargs,
)
out = {"tokens": tokens}
elif self.type == "seq2seq_whisper":
stride = model_inputs.pop("stride", None)
tokens = self.model.generate(
input_features=model_inputs.pop("input_features"),
logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
**generate_kwargs,
)
out = {"tokens": tokens}
if stride is not None:
out["stride"] = stride
if self.type == "seq2seq_whisper":
Reviewer note (Contributor): We needed to pop before generate. If there's no need to pop beforehand, this can be simplified to something like `out = {**out, **model_inputs}`, or something along those lines that includes only `stride`.

stride = model_inputs.pop("stride", None)
if stride is not None:
out["stride"] = stride

else:
stride = model_inputs.pop("stride", None)
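A hedged sketch (not part of this diff) of the pipeline path this change simplifies: `return_timestamps` is now forwarded to `WhisperForConditionalGeneration.generate` through the generate kwargs rather than by injecting the logits processor here. The model id and the audio path are assumptions.

```python
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
)

# "sample.flac" is a placeholder path to any audio file the pipeline can load.
print(asr("sample.flac", return_timestamps=True))
```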
12 changes: 10 additions & 2 deletions tests/models/whisper/test_tokenization_whisper.py
@@ -59,7 +59,7 @@ def test_get_vocab(self):
self.assertEqual(len(vocab_keys), 50364)

def test_vocab_size(self):
self.assertEqual(self.get_tokenizer().vocab_size, 50257)
self.assertEqual(self.get_tokenizer().vocab_size, 50258)

def test_full_tokenizer(self):
tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)
@@ -265,7 +265,15 @@ def test_offset_decoding(self):
},
],
)

# test `decode_with_timestamps`
output = multilingual_tokenizer.decode(INPUT_TOKENS, decode_with_timestamps=True)
self.assertEqual(
output,
"<|startoftranscript|><|en|><|transcribe|><|0.00|> Lennils, pictures are a sort of upguards and atom"
" paintings, and Mason's exquisite idles<|7.20|><|7.20|> are as national as a jingo poem. Mr. Birkut"
" Foster's landscapes smile at one much in the<|15.16|><|15.16|> same way that Mr. Carker used to flash"
" his teeth. And Mr. John Colier gives his<|21.70|><|21.70|><|endoftext|>",
)
# test a single sequence with timestamps
# fmt: off
INPUT_TOKENS = [