79 changes: 15 additions & 64 deletions src/transformers/models/whisper/generation_whisper.py
@@ -511,7 +511,6 @@ def generate(
self._set_language_and_task(
language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
)
-self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
self._set_num_frames(
return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
)
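Note: with the `_set_token_ids` helper gone (its definition is deleted further down), special token ids are no longer plumbed through `kwargs`. A minimal sketch of the resulting calling convention, assuming a standard checkpoint (the model name is illustrative, not part of this PR):

```python
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Per-call overrides of eos_token_id / decoder_start_token_id previously
# went through _set_token_ids; now they live on the generation config,
# which `generate` reads directly.
model.generation_config.eos_token_id = model.config.eos_token_id
model.generation_config.decoder_start_token_id = model.config.decoder_start_token_id
```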
@@ -546,13 +545,13 @@ def generate(
logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token
is_shortform=is_shortform,
-num_beams=kwargs.get("num_beams", 1),
+num_beams=generation_config.num_beams,
)

# 5. If we're in shortform mode, simple generate the whole input at once and return the output
if is_shortform:
if temperature is not None:
kwargs["temperature"] = temperature
generation_config.temperature = temperature

decoder_input_ids = kwargs.pop("decoder_input_ids", None)
if decoder_input_ids is None:
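Note: `num_beams` (and `temperature` in the shortform branch) are now read from `generation_config` instead of raw `kwargs`. Since `generate` merges explicit keyword arguments into the generation config before this point, both spellings below should be equivalent; a hedged sketch with an illustrative zeroed log-mel input:

```python
import torch

features = torch.zeros(1, 80, 3000)  # (batch, mel bins, frames); placeholder input

out_a = model.generate(features, num_beams=2)  # kwarg is folded into the config

model.generation_config.num_beams = 2          # or set it on the config directly
out_b = model.generate(features)
```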
@@ -564,8 +563,8 @@ def generate(
[prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
)

if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
max_new_tokens = kwargs.get("max_new_tokens", 0)
max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
raise ValueError(
f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
f"is {max_new_tokens}. Thus, the combined length of "
@@ -666,11 +665,10 @@ def generate(
)

# 6.6 set max new tokens or max length
-kwargs = self._set_max_new_tokens_and_length(
+self._set_max_new_tokens_and_length(
config=self.config,
decoder_input_ids=decoder_input_ids,
generation_config=generation_config,
-kwargs=kwargs,
)

# 6.7 Set current `begin_index` for all logit processors
@@ -770,9 +768,9 @@ def generate_with_fallback(

for fallback_idx, temperature in enumerate(temperatures):
generation_config.do_sample = temperature is not None and temperature > 0.0

generation_config.temperature = temperature if generation_config.do_sample else 1.0
-generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
+if generation_config.do_sample:
+generation_config.num_beams = 1

generate_kwargs = copy.copy(kwargs)
for key in ["do_sample", "temperature", "num_beams"]:
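This loop is Whisper's temperature fallback: the first attempt decodes deterministically (optionally with beam search), and failed segments are retried with sampling at progressively higher temperatures, where beam search no longer applies. A hedged usage sketch; the threshold arguments follow the existing long-form generation API:

```python
out = model.generate(
    features,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # fallback schedule
    num_beams=2,                       # only used while do_sample is False
    logprob_threshold=-1.0,            # fallback trigger: low avg log-probability
    compression_ratio_threshold=1.35,  # fallback trigger: repetitive output
)
```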
@@ -1095,20 +1093,15 @@ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
task = getattr(generation_config, "task", None)
language = getattr(generation_config, "language", None)

if kwargs.get("forced_decoder_ids", None) is not None:
forced_decoder_ids = kwargs["forced_decoder_ids"]
elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
forced_decoder_ids = generation_config.forced_decoder_ids

forced_decoder_ids = generation_config.forced_decoder_ids
if forced_decoder_ids is not None:
if language is None and task is None and forced_decoder_ids[0][1] is None:
logger.warning_once(
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
)
-elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
-forced_decoder_ids = config.forced_decoder_ids
-else:
-forced_decoder_ids = None

if forced_decoder_ids is not None and task is not None:
logger.info(
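After this cleanup, `generation_config.forced_decoder_ids` is the single source of truth (per-call `kwargs` and the model config are no longer consulted). The `language`/`task` arguments remain the recommended way to pin this behaviour; a hedged sketch (the commented token ids are illustrative values for a multilingual checkpoint):

```python
# Preferred: let `generate` resolve the language and task tokens itself.
out = model.generate(features, language="en", task="transcribe")

# Roughly equivalent intent, expressed on the generation config:
# model.generation_config.forced_decoder_ids = [(1, 50259), (2, 50359)]
```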
@@ -1288,21 +1281,6 @@ def _check_decoder_input_ids(kwargs):
"Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
)

-@staticmethod
-def _set_token_ids(generation_config, config, kwargs):
-eos_token_id = kwargs.pop("eos_token_id", None)
-decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id
-decoder_start_token_id = (
-decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id
-)
-
-generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
-generation_config.decoder_start_token_id = (
-decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id
-)
-
@staticmethod
def _set_num_frames(return_token_timestamps, generation_config, kwargs):
if return_token_timestamps:
@@ -1313,7 +1291,6 @@ def _set_num_frames(return_token_timestamps, generation_config, kwargs):
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)

generation_config.num_frames = kwargs.pop("num_frames", None)

@staticmethod
@@ -1517,47 +1494,21 @@ def _prepare_decoder_input_ids(
return decoder_input_ids, kwargs

@staticmethod
-def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs):
+def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config):
num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)

-passed_max_length = kwargs.pop("max_length", None)
-passed_max_new_tokens = kwargs.pop("max_new_tokens", None)
-max_length_config = getattr(generation_config, "max_length", None)
-max_new_tokens_config = getattr(generation_config, "max_new_tokens", None)
-
-max_new_tokens = None
-max_length = None

# Make sure we don't get larger than `max_length`
-if passed_max_length is not None and passed_max_new_tokens is None:
-max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions)
-logger.info(
-f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment."
-)
-elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None:
+if generation_config.max_length is not None and generation_config.max_new_tokens is None:
max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
logger.info(
f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment."
f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
)
elif (
-passed_max_new_tokens is not None
-and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
+generation_config.max_new_tokens is not None
+and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
):
max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-elif (
-passed_max_new_tokens is None
-and max_new_tokens_config is not None
-and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions
-):
-max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-
-if max_new_tokens is not None:
-kwargs["max_new_tokens"] = max_new_tokens
-
-if max_length is not None:
-kwargs["max_length"] = max_length
-
-return kwargs
+generation_config.max_new_tokens = max_new_tokens

@staticmethod
def _retrieve_compression_ratio(tokens, vocab_size):
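The rewritten helper extends `max_length` by the number of conditioning tokens (the previous-segment prompt) without exceeding the model limit. A worked example with illustrative numbers:

```python
max_target_positions = 448  # model limit
decoder_input_len = 100     # previous-segment tokens + special start tokens
requested_max_length = 400  # generation_config.max_length

num_initial_tokens = min(max_target_positions // 2 - 1, decoder_input_len - 1)
# = min(223, 99) = 99: conditioning tokens should not eat into the budget.

max_length = min(requested_max_length + num_initial_tokens, max_target_positions)
# = min(499, 448) = 448: extended, but capped at the model limit.
```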