diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py
index d22fbaf280de..402db5f24e1e 100644
--- a/src/transformers/generation/beam_search.py
+++ b/src/transformers/generation/beam_search.py
@@ -16,7 +16,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import UserDict
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -212,7 +212,7 @@ def process(
         next_tokens: torch.LongTensor,
         next_indices: torch.LongTensor,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor]:
         cur_len = input_ids.shape[-1]
@@ -234,6 +234,9 @@
         next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
         next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
 
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
                 if self.num_beams < len(beam_hyp):
@@ -253,7 +256,7 @@
             ):
                 batch_beam_idx = batch_idx * self.group_size + next_index
                 # add to generated hypotheses if end of sentence
-                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
+                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                     # if beam_token does not belong to top num_beams tokens, it should not be added
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                     if is_beam_token_worse_than_top_num_beams:
@@ -307,11 +310,14 @@ def finalize(
         final_beam_indices: torch.LongTensor,
         max_length: int,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.LongTensor]:
         batch_size = len(self._beam_hyps)
 
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
@@ -376,7 +382,8 @@
             indices[i, : len(best_idx)] = torch.tensor(best_idx)
 
             if sent_lengths[i] < sent_max_len:
-                decoded[i, sent_lengths[i]] = eos_token_id
+                # inserting only the first eos_token_id
+                decoded[i, sent_lengths[i]] = eos_token_id[0]
 
         return UserDict(
             {
@@ -491,7 +498,7 @@ def process(
         next_indices: torch.LongTensor,
         scores_for_all_vocab: torch.FloatTensor,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
     ) -> Tuple[torch.Tensor]:
         r"""
         Args:
@@ -549,6 +556,9 @@
         next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
         next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
 
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
                 if self.num_beams < len(beam_hyp):
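The scorer changes above all follow the same pattern: coerce a scalar `eos_token_id` into a one-element list up front, then replace equality checks with membership tests. A minimal standalone sketch of that pattern (the helper name `normalize_eos` is illustrative, not part of the diff):

```python
from typing import List, Optional, Union


def normalize_eos(eos_token_id: Optional[Union[int, List[int]]]) -> Optional[List[int]]:
    # Coerce a scalar id into a one-element list so downstream code can assume a list.
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    return eos_token_id


eos = normalize_eos(2)
assert eos == [2]
# The old scalar equality `token == eos_token_id` becomes a membership test:
assert 2 in eos and 7 not in eos
```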
@@ -568,7 +578,7 @@
             ):
                 batch_beam_idx = batch_idx * self.group_size + next_index
                 # add to generated hypotheses if end of sentence
-                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
+                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                     # if beam_token does not belong to top num_beams tokens, it should not be added
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                     if is_beam_token_worse_than_top_num_beams:
@@ -773,10 +783,13 @@ def finalize(
         final_beam_indices: torch.LongTensor,
         max_length: int,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
     ) -> Tuple[torch.LongTensor]:
         batch_size = len(self._beam_hyps)
 
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
@@ -840,7 +853,8 @@
         for i, hypo in enumerate(best):
             decoded[i, : sent_lengths[i]] = hypo
 
             if sent_lengths[i] < sent_max_len:
-                decoded[i, sent_lengths[i]] = eos_token_id
+                # inserting only the first eos_token_id
+                decoded[i, sent_lengths[i]] = eos_token_id[0]
 
         return UserDict(
             {
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 5f54008e16a8..721383c34e76 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -15,7 +15,7 @@
 import inspect
 import math
-from typing import Callable, Iterable, List, Optional, Tuple
+from typing import Callable, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -100,16 +100,18 @@ class MinLengthLogitsProcessor(LogitsProcessor):
     Args:
         min_length (`int`):
             The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
-        eos_token_id (`int`):
+        eos_token_id (`Union[int, List[int]]`):
             The id of the *end-of-sequence* token.
     """
 
-    def __init__(self, min_length: int, eos_token_id: int):
+    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
         if not isinstance(min_length, int) or min_length < 0:
             raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
 
-        if not isinstance(eos_token_id, int) or eos_token_id < 0:
-            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
 
         self.min_length = min_length
         self.eos_token_id = eos_token_id
@@ -117,7 +119,8 @@ def __init__(self, min_length: int, eos_token_id: int):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len < self.min_length:
-            scores[:, self.eos_token_id] = -float("inf")
+            for i in self.eos_token_id:
+                scores[:, i] = -float("inf")
         return scores
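With this change, `MinLengthLogitsProcessor` suppresses every id in the list until `min_length` is reached, rather than a single column. A quick sanity check, assuming the patched class from this diff is installed (the ids and vocabulary size are arbitrary):

```python
import torch
from transformers import MinLengthLogitsProcessor

processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=[2, 3])

input_ids = torch.tensor([[0, 1, 4]])  # cur_len = 3, still below min_length
scores = torch.zeros((1, 10))          # batch of 1, vocabulary of 10
out = processor(input_ids, scores)

# Both EOS columns are masked, so neither stop token can be picked yet.
assert out[0, 2] == -float("inf") and out[0, 3] == -float("inf")
```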
""" - def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int): + def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]): if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.") @@ -413,7 +416,10 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int): f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." ) - bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + bad_words_ids = list(filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)) self.bad_words_id_length_1 = [] self.bad_words_id_length_greater_than_1 = [] for word in bad_words_ids: @@ -628,20 +634,23 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): Args: max_length (`int`): The maximum length of the sequence to be generated. - eos_token_id (`int`): + eos_token_id (`Union[int, List[int]]`): The id of the token to force as the last generated token when `max_length` is reached. """ - def __init__(self, max_length: int, eos_token_id: int): + def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): self.max_length = max_length + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] if cur_len == self.max_length - 1: num_tokens = scores.shape[1] - scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf") - scores[:, self.eos_token_id] = 0 + scores[:, [i for i in range(num_tokens) if i in self.eos_token_id]] = -float("inf") + for i in self.eos_token_id: + scores[:, i] = 0 return scores @@ -671,23 +680,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): exponential_decay_length_penalty (`tuple(int, float)`, *optional*): This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay - eos_token_id (`int`): + eos_token_id (`Union[int, List[int]]`): The id of the *end-of-sequence* token. input_ids_seq_length (`int`): The length of the input sequence. 
""" - def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int): + def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int): self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length self.regulation_factor = exponential_decay_length_penalty[1] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] if cur_len > self.regulation_start: - scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow( - self.regulation_factor, cur_len - self.regulation_start - ) + for i in self.eos_token_id: + scores[:, i] = scores[:, i] * pow( + self.regulation_factor, cur_len - self.regulation_start + ) return scores diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 89a036f43d8f..1c83854e946c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -569,11 +569,13 @@ def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, pad_token_id: Optional[int], - eos_token_id: Optional[int], + eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: @@ -768,9 +770,9 @@ def _get_logits_processor( bad_words_ids: List[List[int]], min_length: int, max_length: int, - eos_token_id: int, + eos_token_id: Union[int, List[int]], forced_bos_token_id: int, - forced_eos_token_id: int, + forced_eos_token_id: Union[int, List[int]], prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int, num_beam_groups: int, @@ -801,6 +803,8 @@ def _get_logits_processor( ) bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty forced_bos_token_id = ( forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id @@ -919,7 +923,7 @@ def compute_transition_beam_scores( sequences: torch.Tensor, scores: Tuple[torch.Tensor], beam_indices: torch.Tensor, - eos_token_id: int = None, + eos_token_id: Union[int, List[int]] = None, ): """compute the transition probabilities of sequences given generation scores and beam indices""" @@ -1022,7 +1026,7 @@ def generate( force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, encoder_no_repeat_ngram_size: 
@@ -1022,7 +1026,7 @@ def generate(
         force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
         bos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         length_penalty: Optional[float] = None,
         no_repeat_ngram_size: Optional[int] = None,
         encoder_no_repeat_ngram_size: Optional[int] = None,
@@ -1044,7 +1048,7 @@
         output_scores: Optional[bool] = None,
         return_dict_in_generate: Optional[bool] = None,
         forced_bos_token_id: Optional[int] = None,
-        forced_eos_token_id: Optional[int] = None,
+        forced_eos_token_id: Optional[Union[int, List[int]]] = None,
         remove_invalid_values: Optional[bool] = None,
         synced_gpus: Optional[bool] = False,
         exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
@@ -1325,14 +1329,18 @@
         if eos_token_id is None and hasattr(self.config, "decoder"):
             eos_token_id = self.config.decoder.eos_token_id
 
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
         if pad_token_id is None and eos_token_id is not None:
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
                     "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                 )
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            pad_token_id = eos_token_id
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id[0]} for open-end generation.")
+            # fall back to the first eos_token_id as the padding token
+            pad_token_id = eos_token_id[0]
 
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1816,7 +1824,7 @@ def contrastive_search(
         logits_warper: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -1900,6 +1908,8 @@
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2118,7 +2128,7 @@
 
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id) == 0).long())
 
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@@ -2155,7 +2165,7 @@ def greedy_search(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
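The per-step stopping update generalizes the old scalar test `next_tokens != eos_token_id` to "matches none of the EOS ids": the summed comparison counts how many EOS ids each sampled token hits, so comparing it against zero yields the mask of sequences that should keep generating. A quick check of that expression in isolation:

```python
import torch

eos_token_id = [2, 3]
next_tokens = torch.tensor([5, 2, 3, 7])
unfinished_sequences = torch.ones(4, dtype=torch.long)

# 1 where the sampled token matches none of the EOS ids, 0 where it hit one.
still_unfinished = (sum(next_tokens == i for i in eos_token_id) == 0).long()
unfinished_sequences = unfinished_sequences.mul(still_unfinished)
assert unfinished_sequences.tolist() == [1, 0, 0, 1]
```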
@@ -2255,6 +2265,8 @@
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2346,7 +2358,7 @@
 
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id) == 0).long())
 
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@@ -2384,7 +2396,7 @@ def sample(
         logits_warper: Optional[LogitsProcessorList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -2503,6 +2515,8 @@
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2597,7 +2611,7 @@
 
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id) == 0).long())
 
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@@ -2635,7 +2649,7 @@ def beam_search(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -2945,7 +2959,7 @@ def beam_sample(
         logits_warper: Optional[LogitsProcessorList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -3260,7 +3274,7 @@ def group_beam_search(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
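Taken together, these signature changes let callers pass several stop tokens straight to `generate`. A hypothetical usage sketch, assuming this patch is installed (the checkpoint and the extra stop id are placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The answer is", return_tensors="pt")

# Stop on either the model's own EOS token or a newline token.
newline_id = tokenizer.encode("\n")[0]
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    eos_token_id=[tokenizer.eos_token_id, newline_id],
)
print(tokenizer.decode(outputs[0]))
```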
@@ -3384,6 +3398,8 @@
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -3622,7 +3638,7 @@ def constrained_beam_search(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -3753,6 +3769,8 @@
             warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
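Because every entry point now normalizes a scalar id to a one-element list, `eos_token_id=2` and `eos_token_id=[2]` should be interchangeable. A small regression-style check along those lines, assuming the patched library (greedy decoding is deterministic, so the two calls should match exactly):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

with torch.no_grad():
    out_scalar = model.generate(**inputs, max_new_tokens=10, eos_token_id=tokenizer.eos_token_id)
    out_list = model.generate(**inputs, max_new_tokens=10, eos_token_id=[tokenizer.eos_token_id])

assert torch.equal(out_scalar, out_list)
```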