From c0dc0232bc1b6ae8d49c99f42f477b195a712dd9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Nov 2022 15:55:30 +0000 Subject: [PATCH 01/20] generate from config mvp --- .../generation/configuration_utils.py | 68 +- src/transformers/generation/utils.py | 736 +++++++----------- 2 files changed, 330 insertions(+), 474 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 07a97c7f2522..0ac1a68a80d9 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -20,6 +20,7 @@ from typing import Any, Dict, Optional, Union from .. import __version__ +from ..configuration_utils import PretrainedConfig from ..utils import ( GENERATION_CONFIG_NAME, PushToHubMixin, @@ -73,6 +74,9 @@ class GenerationConfig(PushToHubMixin): [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. penalty_alpha (`float`, *optional*): The values balance the model confidence and the degeneration penalty in contrastive search decoding. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. > Parameters for manipulation of the model output logits @@ -108,13 +112,13 @@ class GenerationConfig(PushToHubMixin): words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should use the past last key/values attentions (if applicable to the model) to - speed up decoding. renormalize_logits (`bool`, *optional*, defaults to `False`): Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. + constraints (`List[Constraint]`, *optional*): + Custom constraints that can be added to the generation to ensure that the output will contain the use of + certain tokens as defined by `Constraint` objects, in the most sensible way possible. forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`): The id of the token to force as the first generated token after the `decoder_start_token_id`. 
Useful for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target @@ -191,6 +195,7 @@ def __init__(self, **kwargs): self.num_beams = kwargs.pop("num_beams", 1) self.num_beam_groups = kwargs.pop("num_beam_groups", 1) self.penalty_alpha = kwargs.pop("penalty_alpha", None) + self.use_cache = kwargs.pop("use_cache", True) # Parameters for manipulation of the model output logits self.temperature = kwargs.pop("temperature", 1.0) @@ -203,6 +208,8 @@ def __init__(self, **kwargs): self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.bad_words_ids = kwargs.pop("bad_words_ids", None) self.force_word_ids = kwargs.pop("force_word_ids", None) + self.renormalize_logits = kwargs.pop("renormalize_logits", False) + self.constraints = kwargs.pop("constraints", None) self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) self.remove_invalid_values = kwargs.pop("remove_invalid_values", False) @@ -484,18 +491,11 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": kwargs["_commit_hash"] = config_dict["_commit_hash"] config = cls(**config_dict) - - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) + unused_kwargs = config.update(kwargs) logger.info(f"Generate config {config}") if return_unused_kwargs: - return config, kwargs + return config, unused_kwargs else: return config @@ -568,3 +568,47 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = """ with open(json_file_path, "w", encoding="utf-8") as writer: writer.write(self.to_json_string(use_diff=use_diff)) + + @classmethod + def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig": + """ + Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy + [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`]. + + Args: + model_config (`PretrainedConfig`): + The model config that will be used to instantiate the generation config. + + Returns: + [`GenerationConfig`]: The configuration object instantiated from those parameters. + """ + config_dict = model_config.to_dict() + config = cls.from_dict(config_dict, return_unused_kwargs=False) + + # Handles a few special cases + if config.eos_token_id is None and hasattr(config_dict, "decoder"): + config.eos_token_id = config_dict["decoder"]["eos_token_id"] + + return config + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. 
+ """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3d945b2be37a..338a710de840 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import warnings from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -33,8 +34,9 @@ ) from ..pytorch_utils import torch_int_div from ..utils import ModelOutput, logging -from .beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint +from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from .configuration_utils import GenerationConfig from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, ExponentialDecayLengthPenalty, @@ -722,166 +724,150 @@ def _reorder_cache(self, past, beam_idx): def _get_logits_warper( self, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - temperature: Optional[float] = None, - num_beams: Optional[int] = None, - renormalize_logits: Optional[bool] = None, + generation_config: GenerationConfig, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances used for multinomial sampling. 
""" - # init warp parameters - top_k = top_k if top_k is not None else self.config.top_k - top_p = top_p if top_p is not None else self.config.top_p - typical_p = typical_p if typical_p is not None else self.config.typical_p - temperature = temperature if temperature is not None else self.config.temperature # instantiate warpers list warpers = LogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` - if temperature is not None and temperature != 1.0: - warpers.append(TemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) - if top_p is not None and top_p < 1.0: - warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) - if typical_p is not None and typical_p < 1.0: - warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append( + TopKLogitsWarper( + top_k=generation_config.top_k, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) + ) + ) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append( + TopPLogitsWarper( + top_p=generation_config.top_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) + ) + ) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + warpers.append( + TypicalLogitsWarper( + mass=generation_config.typical_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) + ) + ) # `LogitNormalization` should always be the last logit processor, when present - if renormalize_logits is True: + if generation_config.renormalize_logits is True: warpers.append(LogitNormalization()) return warpers def _get_logits_processor( self, - repetition_penalty: float, - no_repeat_ngram_size: int, - encoder_no_repeat_ngram_size: int, + generation_config: GenerationConfig, input_ids_seq_length: int, encoder_input_ids: torch.LongTensor, - bad_words_ids: List[List[int]], - min_length: int, - max_length: int, - eos_token_id: int, - forced_bos_token_id: int, - forced_eos_token_id: int, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], - num_beams: int, - num_beam_groups: int, - diversity_penalty: float, - remove_invalid_values: bool, - exponential_decay_length_penalty: Tuple, logits_processor: Optional[LogitsProcessorList], - renormalize_logits: Optional[bool], - suppress_tokens: Optional[List[int]] = None, - begin_suppress_tokens: Optional[List[int]] = None, - forced_decoder_ids: Optional[List[List[int]]] = None, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] instances used to modify the scores of the language model head. 
""" - processors = LogitsProcessorList() - - # init warp parameters - repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty - no_repeat_ngram_size = ( - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size - ) - encoder_no_repeat_ngram_size = ( - encoder_no_repeat_ngram_size - if encoder_no_repeat_ngram_size is not None - else self.config.encoder_no_repeat_ngram_size - ) - bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty - forced_bos_token_id = ( - forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id - ) - forced_eos_token_id = ( - forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id - ) - remove_invalid_values = ( - remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values - ) - exponential_decay_length_penalty = ( - exponential_decay_length_penalty - if exponential_decay_length_penalty is not None - else self.config.exponential_decay_length_penalty - ) - suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens - begin_suppress_tokens = ( - begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens - ) - if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): - forced_decoder_ids = self.config.forced_decoder_ids # instantiate processors list + processors = LogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` - if diversity_penalty is not None and diversity_penalty > 0.0: + if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0: processors.append( HammingDiversityLogitsProcessor( - diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups + diversity_penalty=generation_config.diversity_penalty, + num_beams=generation_config.num_beams, + num_beam_groups=generation_config.num_beam_groups, ) ) - if repetition_penalty is not None and repetition_penalty != 1.0: - processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) - if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: - processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) - if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: + if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) + if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: + processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) + if ( + generation_config.encoder_no_repeat_ngram_size is not None + and generation_config.encoder_no_repeat_ngram_size > 0 + ): if self.config.is_encoder_decoder: - processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) + processors.append( + EncoderNoRepeatNGramLogitsProcessor( + generation_config.encoder_no_repeat_ngram_size, 
encoder_input_ids + ) + ) else: raise ValueError( "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" ) - if bad_words_ids is not None: - processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) - if min_length is not None and eos_token_id is not None and min_length > 0: - processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) + if generation_config.bad_words_ids is not None: + processors.append( + NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) + ) + if ( + generation_config.min_length is not None + and generation_config.eos_token_id is not None + and generation_config.min_length > 0 + ): + processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) if prefix_allowed_tokens_fn is not None: - processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) - if forced_bos_token_id is not None: - processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) - if forced_eos_token_id is not None: - processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) - if remove_invalid_values is True: + processors.append( + PrefixConstrainedLogitsProcessor( + prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups + ) + ) + if generation_config.forced_bos_token_id is not None: + processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) + if generation_config.forced_eos_token_id is not None: + processors.append( + ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) + ) + if generation_config.remove_invalid_values is True: processors.append(InfNanRemoveLogitsProcessor()) - if exponential_decay_length_penalty is not None: + if generation_config.exponential_decay_length_penalty is not None: processors.append( - ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) + ExponentialDecayLengthPenalty( + generation_config.exponential_decay_length_penalty, + generation_config.eos_token_id, + generation_config.input_ids_seq_length, + ) ) - if suppress_tokens is not None: - processors.append(SuppressTokensLogitsProcessor(suppress_tokens)) - if begin_suppress_tokens is not None: + if generation_config.suppress_tokens is not None: + processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens)) + if generation_config.begin_suppress_tokens is not None: begin_index = input_ids_seq_length - begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1 - if forced_decoder_ids is not None: - begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced - processors.append(SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)) - if forced_decoder_ids is not None: - processors.append(ForceTokensLogitsProcessor(forced_decoder_ids)) + begin_index = ( + begin_index + if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) + else begin_index + 1 + ) + if generation_config.forced_decoder_ids is not None: + begin_index += generation_config.forced_decoder_ids[-1][ + 0 + ] # generation starts after the last token that is forced + processors.append( + SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) + ) + if generation_config.forced_decoder_ids is not None: + 
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) processors = self._merge_criteria_processor_list(processors, logits_processor) # `LogitNormalization` should always be the last logit processor, when present - if renormalize_logits is True: + if generation_config.renormalize_logits is True: processors.append(LogitNormalization()) return processors def _get_stopping_criteria( - self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList] + self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() - if max_length is not None: - criteria.append(MaxLengthCriteria(max_length=max_length)) - if max_time is not None: - criteria.append(MaxTimeCriteria(max_time=max_time)) + if generation_config.max_length is not None: + criteria.append(MaxLengthCriteria(max_length=generation_config.max_length)) + if generation_config.max_time is not None: + criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -999,50 +985,12 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): def generate( self, inputs: Optional[torch.Tensor] = None, - max_length: Optional[int] = None, - min_length: Optional[int] = None, - do_sample: Optional[bool] = None, - early_stopping: Optional[bool] = None, - num_beams: Optional[int] = None, - temperature: Optional[float] = None, - penalty_alpha: Optional[float] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - repetition_penalty: Optional[float] = None, - bad_words_ids: Optional[Iterable[int]] = None, - force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, - bos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - no_repeat_ngram_size: Optional[int] = None, - encoder_no_repeat_ngram_size: Optional[int] = None, - num_return_sequences: Optional[int] = None, - max_time: Optional[float] = None, - max_new_tokens: Optional[int] = None, - decoder_start_token_id: Optional[int] = None, - use_cache: Optional[bool] = None, - num_beam_groups: Optional[int] = None, - diversity_penalty: Optional[float] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, - renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - constraints: Optional[List[Constraint]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - forced_bos_token_id: Optional[int] = None, - forced_eos_token_id: Optional[int] = None, - remove_invalid_values: Optional[bool] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = False, - exponential_decay_length_penalty: Optional[Tuple[int, float]] = None, - suppress_tokens: Optional[List[int]] = None, - begin_suppress_tokens: Optional[List[int]] = None, - forced_decoder_ids: Optional[List[List[int]]] = None, - **model_kwargs, + **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" @@ -1066,9 +1014,11 @@ def 
generate( - Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as - defined in the model's config (`config.json`) which in turn defaults to the - [`~modeling_utils.PretrainedConfig`] of the model. + Apart from `inputs` and the other keyword arguments, all generation-controlling parameters will default to the + value of the corresponding attribute in `generation_config`. You can override these defaults by passing the + corresponding parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + You can view the full list of attributes in [`~generation.GenerationConfig`]. @@ -1081,81 +1031,21 @@ def generate( method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. - max_length (`int`, *optional*, defaults to `model.config.max_length`): - The maximum length the generated tokens can have. Corresponds to the length of the input prompt + - `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in - the prompt. - max_new_tokens (`int`, *optional*): - The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. - min_length (`int`, *optional*, defaults to `model.config.min_length` or 10 if the config does not set any value): - The minimum length of the sequence to be generated. - do_sample (`bool`, *optional*, defaults to `model.config.do_sample` or `False` if the config does not set any value): - Whether or not to use sampling ; use greedy decoding otherwise. - early_stopping (`bool`, *optional*, defaults to `False`): - Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. - num_beams (`int`, *optional*, defaults to `model.config.num_beams` or 1 if the config does not set any value): - Number of beams for beam search. 1 means no beam search. - temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value): - The value used to module the next token probabilities. - penalty_alpha (`float`, *optional*, defaults to `model.config.penalty_alpha` or None if the config does not set any value): - The values balance the model confidence and the degeneration penalty in contrastive search decoding. - top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value): - If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to - `top_p` or higher are kept for generation. - typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value): - The amount of probability mass from the original distribution to be considered in typical decoding. If - set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details. - repetition_penalty (`float`, *optional*, defaults to `model.config.repetition_penalty` or 1.0 if the config does not set any value): - The parameter for repetition penalty. 1.0 means no penalty. See [this - paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. 
- pad_token_id (`int`, *optional*, defaults to `model.config.pad_token_id`): - The id of the *padding* token. - bos_token_id (`int`, *optional*, defaults to `model.config.bos_token_id`): - The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*, defaults to `model.config.eos_token_id`): - The id of the *end-of-sequence* token. - length_penalty (`float`, *optional*, defaults to `model.config.length_penalty` or 1.0 if the config does not set any value): - Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent - to the sequence length, which in turn is used to divide the score of the sequence. Since the score is - the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, - while `length_penalty` < 0.0 encourages shorter sequences. - no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.no_repeat_ngram_size` or 0 if the config does not set any value): - If set to int > 0, all ngrams of that size can only occur once. - encoder_no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.encoder_no_repeat_ngram_size` or 0 if the config does not set any value): - If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the - `decoder_input_ids`. - bad_words_ids(`List[List[int]]`, *optional*, defaults to `model.config.bad_words_ids`): - List of token ids that are not allowed to be generated. In order to get the token ids of the words that - should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, - add_special_tokens=False).input_ids`. - force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): - List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple - list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, - this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), - where one can allow different forms of each word. - num_return_sequences(`int`, *optional*, defaults to `model.config.num_return_sequences` or 1 if the config does not set any value): - The number of independently computed returned sequences for each element in the batch. - max_time(`float`, *optional*): - The maximum amount of time you allow the computation to run for in seconds. generation will still - finish the current pass after allocated time has been passed. - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens - that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape - as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) - decoder_start_token_id (`int`, *optional*): - If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should use the past last key/values attentions (if applicable to the model) to - speed up decoding. - num_beam_groups (`int`, *optional*, defaults to `model.config.num_beam_groups` or 1 if the config does not set any value): - Number of groups to divide `num_beams` into in order to ensure diversity among different groups of - beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. 
- diversity_penalty (`float`, *optional*, defaults to `model.config.diversity_penalty` or 0.0 if the config does not set any value): - This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is - enabled. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, a default will be used, with the following loading priority: 1) + from the `generation_config.json` model file, if it exists; 2) from the `self.config` model attribute, + containing a model configuration. Please note that unspecified parameters will inherit + [`~generation.GenerationConfig`]'s default values. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and @@ -1163,60 +1053,12 @@ def generate( on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and a - model's config. If a logit processor is passed that is already created with the arguments or a model's - config an error is thrown. This feature is intended for advanced users. - renormalize_logits (`bool`, *optional*, defaults to `False`): - Whether to renormalize the logits after applying all the logits processors or warpers (including the - custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the - score logits are normalized but some logit processors or warpers break the normalization. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - model's config. If a stopping criteria is passed that is already created with the arguments or a - model's config an error is thrown. This feature is intended for advanced users. - constraints (`List[Constraint]`, *optional*): - Custom constraints that can be added to the generation to ensure that the output will contain the use - of certain tokens as defined by `Constraint` objects, in the most sensible way possible. 
- output_attentions (`bool`, *optional*, defaults to `model.config.output_attentions` or `False` if the config does not set any value): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `model.config.output_hidden_states` or `False` if the config does not set any value): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `model.config.output_scores` or `False` if the config does not set any value): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `model.config.return_dict_in_generate` or `False` if the config does not set any value): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`): - The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful - for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be - the target language token. - forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`): - The id of the token to force as the last generated token when `max_length` is reached. - remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`): - Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to - crash. Note that using `remove_invalid_values` can slow down generation. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`): - This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been - generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates - where penalty starts and `decay_factor` represents the factor of exponential decay - suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`): - A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set - their log probs to `-inf` so that they are not sampled. - begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`): - A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` - logit processor will set their log probs to `-inf` so that they are not sampled. - forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`): - A list of pairs of integers which indicates a mapping from generation indices to token indices that - will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always - be a token of index 123. - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model - is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs - should be prefixed with *decoder_*. 
+ kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` @@ -1291,71 +1133,63 @@ def generate( >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" - # 0. Validate the `.generate()` call + # 1. Handle `generation_config` and kwargs that might update it + if generation_config is None: + try: + generation_config = GenerationConfig.from_pretrained(self.config.name_or_path) + except EnvironmentError: + generation_config = GenerationConfig.from_model_config(self.config) + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(kwargs) # All unused kwargs must be model kwargs + + # 2. Validate the `.generate()` call self._validate_model_class() self._validate_model_kwargs(model_kwargs.copy()) - # 1. Set generation parameters if not already defined - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id - num_beams = num_beams if num_beams is not None else self.config.num_beams - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups - do_sample = do_sample if do_sample is not None else self.config.do_sample - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences - ) + # 3. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - - if eos_token_id is None and hasattr(self.config, "decoder"): - eos_token_id = self.config.decoder.eos_token_id - - if pad_token_id is None and eos_token_id is not None: + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 
) - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - pad_token_id = eos_token_id - - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate - ) + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = generation_config.eos_token_id - # 2. Define model inputs + # 4. Define model inputs # inputs_tensor has to be defined # model_input_name is defined if model-specific keyword input is passed # otherwise model_input_name is None # all model-specific keyword inputs are removed from `model_kwargs` - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) batch_size = inputs_tensor.shape[0] - # 3. Define other model kwargs - model_kwargs["output_attentions"] = output_attentions - model_kwargs["output_hidden_states"] = output_hidden_states - model_kwargs["use_cache"] = use_cache + # 5. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, pad_token_id, eos_token_id + inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: - if pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 + ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer." @@ -1368,12 +1202,12 @@ def generate( inputs_tensor, model_kwargs, model_input_name ) - # 4. Prepare `input_ids` which will be used for auto-regressive generation + # 6. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: input_ids = self._prepare_decoder_input_ids_for_generation( batch_size, - decoder_start_token_id=decoder_start_token_id, - bos_token_id=bos_token_id, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, model_kwargs=model_kwargs, device=inputs_tensor.device, ) @@ -1381,87 +1215,91 @@ def generate( # if decoder-only then inputs_tensor has to be `input_ids` input_ids = inputs_tensor - # 5. 
Prepare `max_length` depending on other stopping criteria. + # 7. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] - if max_length is None and max_new_tokens is None: + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length == 20 + if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to " - f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is " - "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend " - "using `max_new_tokens` to control the maximum length of the generation.", + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) - elif max_length is None and max_new_tokens is not None: - max_length = max_new_tokens + input_ids_seq_length - elif max_length is not None and max_new_tokens is not None: + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + elif not has_default_max_length and generation_config.max_new_tokens is not None: raise ValueError( "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" " limit to the generated output length. Remove one of those arguments. Please refer to the" " documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) - # default to config if still None - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length - if min_length is not None and min_length > max_length: + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( - f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum " - f"length ({max_length})" + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" ) - if input_ids_seq_length >= max_length: + if input_ids_seq_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {max_length}. This can lead to unexpected behavior. You should consider increasing " - "`max_new_tokens`." + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." ) # 6. 
determine generation mode - is_constraint_gen_mode = constraints is not None or force_words_ids is not None + is_constraint_gen_mode = ( + generation_config.constraints is not None or generation_config.force_words_ids is not None + ) is_contrastive_search_gen_mode = ( - top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0 + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 ) is_greedy_gen_mode = ( - (num_beams == 1) - and (num_beam_groups == 1) - and do_sample is False + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_sample_gen_mode = ( - (num_beams == 1) - and (num_beam_groups == 1) - and do_sample is True + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_gen_mode = ( - (num_beams > 1) - and (num_beam_groups == 1) - and do_sample is False + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_sample_gen_mode = ( - (num_beams > 1) - and (num_beam_groups == 1) - and do_sample is True + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_group_beam_gen_mode = ( - (num_beams > 1) - and (num_beam_groups > 1) + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) - if num_beam_groups > num_beams: + if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") - if is_group_beam_gen_mode and do_sample is True: + if is_group_beam_gen_mode and generation_config.do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." ) @@ -1479,39 +1317,23 @@ def generate( # 7. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=inputs_tensor, - bad_words_ids=bad_words_ids, - min_length=min_length, - max_length=max_length, - eos_token_id=eos_token_id, - forced_bos_token_id=forced_bos_token_id, - forced_eos_token_id=forced_eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - num_beams=num_beams, - num_beam_groups=num_beam_groups, - diversity_penalty=diversity_penalty, - remove_invalid_values=remove_invalid_values, - exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, - renormalize_logits=renormalize_logits, - suppress_tokens=suppress_tokens, - begin_suppress_tokens=begin_suppress_tokens, - forced_decoder_ids=forced_decoder_ids, ) # 8. 
prepare stopping criteria stopping_criteria = self._get_stopping_criteria( - max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria + generation_config=generation_config, stopping_criteria=stopping_criteria ) # 9. go into different generation modes if is_greedy_gen_mode: - if num_return_sequences > 1: + if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." ) # 10. run greedy search @@ -1519,50 +1341,44 @@ def generate( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_contrastive_search_gen_mode: - if num_return_sequences > 1: + if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search." + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." ) return self.contrastive_search( input_ids, - top_k=top_k, - penalty_alpha=penalty_alpha, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_sample_gen_mode: # 10. prepare logits warper - logits_warper = self._get_logits_warper( - top_k=top_k, - top_p=top_p, - typical_p=typical_p, - temperature=temperature, - num_beams=num_beams, - renormalize_logits=renormalize_logits, - ) + logits_warper = self._get_logits_warper(generation_config) # 11. 
expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_return_sequences, + expand_size=generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) @@ -1573,16 +1389,16 @@ def generate( logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_gen_mode: - if num_return_sequences > num_beams: + if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: @@ -1591,16 +1407,16 @@ def generate( # 10. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, - num_beams=num_beams, + num_beams=generation_config.num_beams, device=inputs_tensor.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, ) # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_beams, + expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) @@ -1610,40 +1426,33 @@ def generate( beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_sample_gen_mode: # 10. prepare logits warper - logits_warper = self._get_logits_warper( - top_k=top_k, - top_p=top_p, - typical_p=typical_p, - temperature=temperature, - num_beams=num_beams, - renormalize_logits=renormalize_logits, - ) + logits_warper = self._get_logits_warper(generation_config) if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( - batch_size=batch_size * num_return_sequences, - num_beams=num_beams, + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, device=inputs_tensor.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, ) # 12. 
interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_beams * num_return_sequences, + expand_size=generation_config.num_beams * generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) @@ -1655,42 +1464,42 @@ def generate( logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_group_beam_gen_mode: - if num_return_sequences > num_beams: + if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") - if num_beams % num_beam_groups != 0: + if generation_config.num_beams % generation_config.num_beam_groups != 0: raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - if typical_p is not None: + if generation_config.typical_p is not None: raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") # 10. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, - num_beams=num_beams, + num_beams=generation_config.num_beams, max_length=stopping_criteria.max_length, device=inputs_tensor.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=num_beam_groups, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, ) # 11. 
interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_beams, + expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) @@ -1700,46 +1509,49 @@ def generate( beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_constraint_gen_mode: - if num_return_sequences > num_beams: + if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - if num_beams <= 1: + if generation_config.num_beams <= 1: raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.") - if do_sample: + if generation_config.do_sample: raise ValueError("`do_sample` needs to be false for constrained generation.") - if num_beam_groups is not None and num_beam_groups > 1: + if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1: raise ValueError("`num_beam_groups` not supported yet for constrained generation.") final_constraints = [] - if constraints is not None: - final_constraints = constraints + if generation_config.constraints is not None: + final_constraints = generation_config.constraints - if force_words_ids is not None: + if generation_config.force_words_ids is not None: def typeerror(): raise ValueError( "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" - f"of positive integers, but is {force_words_ids}." + f"of positive integers, but is {generation_config.force_words_ids}." ) - if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): typeerror() - for word_ids in force_words_ids: + for word_ids in generation_config.force_words_ids: if isinstance(word_ids[0], list): if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() @@ -1765,16 +1577,16 @@ def typeerror(): constrained_beam_scorer = ConstrainedBeamSearchScorer( constraints=final_constraints, batch_size=batch_size, - num_beams=num_beams, + num_beams=generation_config.num_beams, device=inputs_tensor.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, ) # 11. 
interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=num_beams, + expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) @@ -1784,10 +1596,10 @@ def typeerror(): constrained_beam_scorer=constrained_beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) From d9de3dfc4d396402c3a4e290afa4e3ddac0745da Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Nov 2022 16:21:25 +0000 Subject: [PATCH 02/20] fix failing tests --- .../generation/configuration_utils.py | 4 +- src/transformers/generation/utils.py | 47 ++++++++++--------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 0ac1a68a80d9..c8ffad4295e4 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -207,7 +207,7 @@ def __init__(self, **kwargs): self.length_penalty = kwargs.pop("length_penalty", 1.0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.bad_words_ids = kwargs.pop("bad_words_ids", None) - self.force_word_ids = kwargs.pop("force_word_ids", None) + self.force_words_ids = kwargs.pop("force_words_ids", None) self.renormalize_logits = kwargs.pop("renormalize_logits", False) self.constraints = kwargs.pop("constraints", None) self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) @@ -491,7 +491,7 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": kwargs["_commit_hash"] = config_dict["_commit_hash"] config = cls(**config_dict) - unused_kwargs = config.update(kwargs) + unused_kwargs = config.update(**kwargs) logger.info(f"Generate config {config}") if return_unused_kwargs: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 338a710de840..fe442cef2fcb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1140,7 +1140,7 @@ def generate( except EnvironmentError: generation_config = GenerationConfig.from_model_config(self.config) generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(kwargs) # All unused kwargs must be model kwargs + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs # 2. Validate the `.generate()` call self._validate_model_class() @@ -1249,7 +1249,7 @@ def generate( " increasing `max_new_tokens`." ) - # 6. determine generation mode + # 8. determine generation mode is_constraint_gen_mode = ( generation_config.constraints is not None or generation_config.force_words_ids is not None ) @@ -1315,7 +1315,7 @@ def generate( UserWarning, ) - # 7. prepare distribution pre_processing samplers + # 9. 
prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, @@ -1324,11 +1324,11 @@ def generate( logits_processor=logits_processor, ) - # 8. prepare stopping criteria + # 10. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) - # 9. go into different generation modes + # 11. go into different generation modes if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( @@ -1336,7 +1336,7 @@ def generate( " greedy search." ) - # 10. run greedy search + # 12. run greedy search return self.greedy_search( input_ids, logits_processor=logits_processor, @@ -1372,10 +1372,10 @@ def generate( ) elif is_sample_gen_mode: - # 10. prepare logits warper + # 12. prepare logits warper logits_warper = self._get_logits_warper(generation_config) - # 11. expand input_ids with `num_return_sequences` additional sequences per batch + # 13. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, @@ -1383,7 +1383,7 @@ def generate( **model_kwargs, ) - # 12. run sample + # 14. run sample return self.sample( input_ids, logits_processor=logits_processor, @@ -1404,7 +1404,7 @@ def generate( if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - # 10. prepare beam search scorer + # 12. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, @@ -1413,14 +1413,14 @@ def generate( do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, ) - # 11. interleave input_ids with `num_beams` additional sequences per batch + # 13. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 12. run beam search + # 14. run beam search return self.beam_search( input_ids, beam_scorer, @@ -1435,12 +1435,12 @@ def generate( ) elif is_beam_sample_gen_mode: - # 10. prepare logits warper + # 12. prepare logits warper logits_warper = self._get_logits_warper(generation_config) if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - # 11. prepare beam search scorer + # 13. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size * generation_config.num_return_sequences, num_beams=generation_config.num_beams, @@ -1449,7 +1449,7 @@ def generate( do_early_stopping=generation_config.early_stopping, ) - # 12. interleave input_ids with `num_beams` additional sequences per batch + # 14. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams * generation_config.num_return_sequences, @@ -1457,7 +1457,7 @@ def generate( **model_kwargs, ) - # 13. run beam sample + # 15. 
run beam sample return self.beam_sample( input_ids, beam_scorer, @@ -1482,10 +1482,11 @@ def generate( if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - if generation_config.typical_p is not None: + has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + if not has_default_typical_p: raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") - # 10. prepare beam search scorer + # 12. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, @@ -1496,14 +1497,14 @@ def generate( num_beam_hyps_to_keep=generation_config.num_return_sequences, num_beam_groups=generation_config.num_beam_groups, ) - # 11. interleave input_ids with `num_beams` additional sequences per batch + # 13. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 12. run beam search + # 14. run beam search return self.group_beam_search( input_ids, beam_scorer, @@ -1573,7 +1574,7 @@ def typeerror(): constraint = PhrasalConstraint(word_ids) final_constraints.append(constraint) - # 10. prepare beam search scorer + # 12. prepare beam search scorer constrained_beam_scorer = ConstrainedBeamSearchScorer( constraints=final_constraints, batch_size=batch_size, @@ -1583,14 +1584,14 @@ def typeerror(): do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, ) - # 11. interleave input_ids with `num_beams` additional sequences per batch + # 13. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 12. run beam search + # 14. run beam search return self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, From 7a92430fb2be63ad4834b4b25ca12dac5b175a97 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Nov 2022 16:50:35 +0000 Subject: [PATCH 03/20] max_time test --- tests/models/gpt2/test_modeling_gpt2.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 2f6f8d12143d..118dcdea1f06 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -18,7 +18,7 @@ import math import unittest -from transformers import GPT2Config, is_torch_available +from transformers import GenerationConfig, GPT2Config, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -727,6 +727,9 @@ def test_gpt2_sample(self): def test_gpt2_sample_max_time(self): tokenizer = GPT2Tokenizer.from_pretrained("gpt2") model = GPT2LMHeadModel.from_pretrained("gpt2") + # Pre-load the generation config to avoid a generate-time `from_pretrained`` call, which would interfere with + # the timing. 
+ generation_config = GenerationConfig.from_pretrained("gpt2") model.to(torch_device) torch.manual_seed(0) @@ -736,31 +739,31 @@ def test_gpt2_sample_max_time(self): MAX_TIME = 0.5 start = datetime.datetime.now() - model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256) + model.generate(input_ids, generation_config, do_sample=True, max_time=MAX_TIME, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256) + model.generate(input_ids, generation_config, do_sample=False, max_time=MAX_TIME, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256) + model.generate(input_ids, generation_config, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) start = datetime.datetime.now() - model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256) + model.generate(input_ids, generation_config, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, max_time=None, max_length=256) + model.generate(input_ids, generation_config, do_sample=False, max_time=None, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) From 326ebabbcb6de1df58a00d478c60a0a1780a3604 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 23 Nov 2022 13:03:02 +0000 Subject: [PATCH 04/20] Load default gen config at model load time; Update docs --- .../en/main_classes/text_generation.mdx | 81 ++++++++++++++- .../generation/configuration_utils.py | 21 +++- src/transformers/generation/utils.py | 98 +++---------------- src/transformers/modeling_utils.py | 25 ++++- tests/models/gpt2/test_modeling_gpt2.py | 15 ++- utils/documentation_tests.txt | 1 + 6 files changed, 142 insertions(+), 99 deletions(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index 78bef8bd5a25..0bdf5d1b1782 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -18,12 +18,91 @@ Each framework has a generate method for auto-regressive text generation impleme - TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`]. - Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`]. - +Regardless of your framework of choice, you can parametrize the generate method with a [`~generation.GenerationConfig`] +class instance. Please refer to this class for the complete list of generation parameters, which control the behavior +of the generation method. 
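As a rough sketch of the parametrization described above (illustrative only — the checkpoint name, prompt, and values are placeholders): any keyword argument passed to `generate` that matches a [`~generation.GenerationConfig`] attribute overrides it for that call, while kwargs that do not match (such as `attention_mask`) are forwarded to the model's forward pass.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Today I believe we can finally", return_tensors="pt")

# `num_beams` and `max_new_tokens` match GenerationConfig attributes, so they override the loaded
# generation config for this call only; `attention_mask` does not, so it is passed through to the
# model's forward pass as a model kwarg.
outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    num_beams=2,
    max_new_tokens=16,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```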
+ + + +If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that +default values are omitted. Some attributes, like `max_new_tokens`, have a conservative default value, to avoid running +into resource limitations. Make sure you double-check the defaults in the documentation. + + + +All models have a default generation configuration that will be used if you don't provide one. In addition, you can +specify ad hoc modifications to the used generation configuration by passing the attribute you wish to override +directly to the generate method. Here's a few illustrative examples: + +- Greedy decoding, using the default generation configuration and ad hoc modifications: + +```python +>>> from transformers import AutoTokenizer, AutoModelForCausalLM + +>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") +>>> model = AutoModelForCausalLM.from_pretrained("gpt2") + +>>> prompt = "Today I believe we can finally" +>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + +>>> # Generate up to 30 tokens +>>> outputs = model.generate(input_ids, do_sample=False, max_length=30) +>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) +['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] +``` + +- Multinomial sampling, modifying an existing generation configuration: + +```python +>>> from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig +>>> import torch + +>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") +>>> model = AutoModelForCausalLM.from_pretrained("gpt2") + +>>> prompt = "Today I believe we can finally" +>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + +>>> # Sample up to 30 tokens +>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT +>>> generation_config = GenerationConfig.from_pretrained("gpt2") +>>> generation_config.max_length = 30 +>>> generation_config.do_sample = True +>>> outputs = model.generate(input_ids, generation_config=generation_config) +>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) +['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] +``` + +Beam-search decoding, using a freshly initialized generation configuration: + +```python +>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig + +>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") +>>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") + +>>> sentence = "Paris is one of the densest populated areas in Europe." +>>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids + +>>> generation_config = GenerationConfig( +... max_length=64, +... num_beams=5, +... bos_token_id=0, +... eos_token_id=0, +... decoder_start_token_id=58100, +... pad_token_id=58100, +... bad_words_ids=[[58100]], +... 
) +>>> outputs = model.generate(input_ids, generation_config=generation_config) +>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) +['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] +``` ## GenerationConfig [[autodoc]] generation.GenerationConfig - from_pretrained + - from_model_config - save_pretrained ## GenerationMixin diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index c8ffad4295e4..d086587d6b71 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -37,7 +37,23 @@ class GenerationConfig(PushToHubMixin): r""" - Class that holds a configuration for a generation task. + Class that holds a configuration for a generation task. A `generate` call supports the following generation methods + for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + + - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.` + and `top_k>1` + - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if + `constraints!=None` or `force_words_ids!=None`. @@ -46,6 +62,9 @@ class GenerationConfig(PushToHubMixin): + Most of these parameters are explained in more detail in [this blog + post](https://huggingface.co/blog/how-to-generate). + Arg: > Parameters that control the length of the output diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index fe442cef2fcb..cdfd3a619d5b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -846,9 +846,8 @@ def _get_logits_processor( else begin_index + 1 ) if generation_config.forced_decoder_ids is not None: - begin_index += generation_config.forced_decoder_ids[-1][ - 0 - ] # generation starts after the last token that is forced + # generation starts after the last token that is forced + begin_index += generation_config.forced_decoder_ids[-1][0] processors.append( SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) ) @@ -994,37 +993,18 @@ def generate( ) -> Union[GenerateOutput, torch.LongTensor]: r""" - Generates sequences of token ids for models with a language modeling head. The method supports the following - generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: - - - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and - `do_sample=False`. - - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.` - and `top_k>1` - - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and - `do_sample=True`. 
- - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and - `do_sample=False`. - - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if - `num_beams>1` and `do_sample=True`. - - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if - `num_beams>1` and `num_beam_groups>1`. - - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if - `constraints!=None` or `force_words_ids!=None`. + Generates sequences of token ids for models with a language modeling head. - Apart from `inputs` and the other keyword arguments, all generation-controlling parameters will default to the - value of the corresponding attribute in `generation_config`. You can override these defaults by passing the - corresponding parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + Apart from `inputs` and the other keyword arguments documented here, all generation-controlling parameters will + default to the value of the corresponding attribute in `generation_config`. You can override these defaults by + passing the corresponding parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. You can view the full list of attributes in [`~generation.GenerationConfig`]. - Most of these parameters are explained in more detail in [this blog - post](https://huggingface.co/blog/how-to-generate). - Parameters: inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the @@ -1034,10 +1014,10 @@ def generate( generation_config (`~generation.GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, a default will be used, with the following loading priority: 1) - from the `generation_config.json` model file, if it exists; 2) from the `self.config` model attribute, - containing a model configuration. Please note that unspecified parameters will inherit - [`~generation.GenerationConfig`]'s default values. + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. logits_processor (`LogitsProcessorList`, *optional*): Custom logits processors that complement the default logits processors built from arguments and generation config. 
If a logit processor is passed that is already created with the arguments or a @@ -1079,66 +1059,10 @@ def generate( - [`~generation.SampleEncoderDecoderOutput`], - [`~generation.BeamSearchEncoderDecoderOutput`], - [`~generation.BeamSampleEncoderDecoderOutput`] - - Examples: - - Greedy Decoding: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> prompt = "Today I believe we can finally" - >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids - - >>> # generate up to 30 tokens - >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] - ``` - - Multinomial Sampling: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> prompt = "Today I believe we can finally" - >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids - - >>> # sample up to 30 tokens - >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT - >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] - ``` - - Beam-search decoding: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM - - >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") - - >>> sentence = "Paris is one of the densest populated areas in Europe." - >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids - - >>> outputs = model.generate(input_ids, num_beams=5) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" # 1. 
Handle `generation_config` and kwargs that might update it if generation_config is None: - try: - generation_config = GenerationConfig.from_pretrained(self.config.name_or_path) - except EnvironmentError: - generation_config = GenerationConfig.from_model_config(self.config) + generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 48e437fd7684..84e1e2e0a45b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -39,7 +39,7 @@ from .configuration_utils import PretrainedConfig from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled from .dynamic_module_utils import custom_object_save -from .generation import GenerationMixin +from .generation import GenerationConfig, GenerationMixin from .pytorch_utils import ( # noqa: F401 Conv1D, apply_chunking_to_forward, @@ -1023,6 +1023,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): # Save config and origin of the pretrained weights if given in model self.config = config self.name_or_path = config.name_or_path + self.generation_config = None # May be overwritten during `from_pretrained()` self.warnings_issued = {} def post_init(self): @@ -2477,6 +2478,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.eval() + # If it is a model with generation capabilities, load a default generation config (from the default path if + # it exists, otherwise from the model config) + if hasattr(cls, "prepare_inputs_for_generation"): + try: + generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except EnvironmentError: + generation_config = GenerationConfig.from_model_config(config) + model.generation_config = generation_config + # Dispatch model with hooks on all devices if necessary if device_map is not None: dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 118dcdea1f06..2f6f8d12143d 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -18,7 +18,7 @@ import math import unittest -from transformers import GenerationConfig, GPT2Config, is_torch_available +from transformers import GPT2Config, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -727,9 +727,6 @@ def test_gpt2_sample(self): def test_gpt2_sample_max_time(self): tokenizer = GPT2Tokenizer.from_pretrained("gpt2") model = GPT2LMHeadModel.from_pretrained("gpt2") - # Pre-load the generation config to avoid a generate-time `from_pretrained`` call, which would interfere with - # the timing. 
- generation_config = GenerationConfig.from_pretrained("gpt2") model.to(torch_device) torch.manual_seed(0) @@ -739,31 +736,31 @@ def test_gpt2_sample_max_time(self): MAX_TIME = 0.5 start = datetime.datetime.now() - model.generate(input_ids, generation_config, do_sample=True, max_time=MAX_TIME, max_length=256) + model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) start = datetime.datetime.now() - model.generate(input_ids, generation_config, do_sample=False, max_time=MAX_TIME, max_length=256) + model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) start = datetime.datetime.now() - model.generate(input_ids, generation_config, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256) + model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) start = datetime.datetime.now() - model.generate(input_ids, generation_config, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256) + model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) start = datetime.datetime.now() - model.generate(input_ids, generation_config, do_sample=False, max_time=None, max_length=256) + model.generate(input_ids, do_sample=False, max_time=None, max_length=256) duration = datetime.datetime.now() - start self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 9293dc3c3934..abc591b74e17 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -3,6 +3,7 @@ docs/source/es/quicktour.mdx docs/source/en/pipeline_tutorial.mdx docs/source/en/autoclass_tutorial.mdx docs/source/en/task_summary.mdx +docs/source/en/main_classes/text_generation.mdx docs/source/en/model_doc/markuplm.mdx docs/source/en/model_doc/speech_to_text.mdx docs/source/en/model_doc/switch_transformers.mdx From 3481fd71cb45e6c99840cb679ce4906fb942342f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 23 Nov 2022 16:40:58 +0000 Subject: [PATCH 05/20] further documentation; add tests --- .../en/main_classes/text_generation.mdx | 102 +++++++----------- src/transformers/generation/utils.py | 74 ++++++++++++- tests/generation/test_configuration_utils.py | 33 +++++- utils/documentation_tests.txt | 1 - 4 files changed, 142 insertions(+), 68 deletions(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index 0bdf5d1b1782..6226b3c0e05c 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -22,81 +22,59 @@ Regardless of your framework of choice, you can parametrize the generate method class instance. 
Please refer to this class for the complete list of generation parameters, which control the behavior of the generation method. - - -If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that -default values are omitted. Some attributes, like `max_new_tokens`, have a conservative default value, to avoid running -into resource limitations. Make sure you double-check the defaults in the documentation. +All models have a default generation configuration that will be used if you don't provide one. If you have a loaded +model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd +like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and +store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty. - +```python +from transformers import AutoModelForCausalLM, GenerationConfig -All models have a default generation configuration that will be used if you don't provide one. In addition, you can -specify ad hoc modifications to the used generation configuration by passing the attribute you wish to override -directly to the generate method. Here's a few illustrative examples: +model = AutoModelForCausalLM.from_pretrained("my_account/my_model") -- Greedy decoding, using the default generation configuration and ad hoc modifications: +# Inspect the default generation configuration +print(model.generation_config) -```python ->>> from transformers import AutoTokenizer, AutoModelForCausalLM +# Set a new default generation configuration +generation_config = GenerationConfig( + max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id +) +generation_config.save_pretrained("my_account/my_model", push_to_hub=True) +``` ->>> tokenizer = AutoTokenizer.from_pretrained("gpt2") ->>> model = AutoModelForCausalLM.from_pretrained("gpt2") + ->>> prompt = "Today I believe we can finally" ->>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids +If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that +default values are omitted. Some attributes, like `max_new_tokens`, have a conservative default value, to avoid running +into resource limitations. Make sure you double-check the defaults in the documentation. ->>> # Generate up to 30 tokens ->>> outputs = model.generate(input_ids, do_sample=False, max_length=30) ->>> tokenizer.batch_decode(outputs, skip_special_tokens=True) -['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] -``` + -- Multinomial sampling, modifying an existing generation configuration: +You can also store several generation parametrizations in a single directory, making use of the `config_file_name` +argument in `save_pretrained`. You can latter instantiate them with `from_pretrained`. This is useful if you want to +store several generation configurations for a single model (e.g. one for creative text generation with sampling, and +other for summarization with beam search). 
```python ->>> from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig ->>> import torch - ->>> tokenizer = AutoTokenizer.from_pretrained("gpt2") ->>> model = AutoModelForCausalLM.from_pretrained("gpt2") - ->>> prompt = "Today I believe we can finally" ->>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids - ->>> # Sample up to 30 tokens ->>> torch.manual_seed(0) # doctest: +IGNORE_RESULT ->>> generation_config = GenerationConfig.from_pretrained("gpt2") ->>> generation_config.max_length = 30 ->>> generation_config.do_sample = True ->>> outputs = model.generate(input_ids, generation_config=generation_config) ->>> tokenizer.batch_decode(outputs, skip_special_tokens=True) -['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] +from transformers import GenerationConfig + +# Create a new generation configuration for later use +my_awesome_config = GenerationConfig(num_beams=8, do_sample=False, bad_words_ids=[[42, 43, 44]], eos_token_id=2) +my_awesome_config.save_pretrained( + "my_account/my_model", config_file_name="awesome_generation_config.json", push_to_hub=True +) + +# Restore the generation configuration +generation_config = GenerationConfig.from_pretrained( + "my_account/my_model", config_file_name="awesome_generation_config.json" +) ``` -Beam-search decoding, using a freshly initialized generation configuration: +Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you +wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each +framework's `generate` method docstring (available below) has a few illustrative examples on the different strategies +to parameterize it. -```python ->>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig - ->>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") ->>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") - ->>> sentence = "Paris is one of the densest populated areas in Europe." ->>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids - ->>> generation_config = GenerationConfig( -... max_length=64, -... num_beams=5, -... bos_token_id=0, -... eos_token_id=0, -... decoder_start_token_id=58100, -... pad_token_id=58100, -... bad_words_ids=[[58100]], -... ) ->>> outputs = model.generate(input_ids, generation_config=generation_config) ->>> tokenizer.batch_decode(outputs, skip_special_tokens=True) -['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] -``` ## GenerationConfig diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index cdfd3a619d5b..201f170f6735 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -997,11 +997,12 @@ def generate( - Apart from `inputs` and the other keyword arguments documented here, all generation-controlling parameters will - default to the value of the corresponding attribute in `generation_config`. You can override these defaults by - passing the corresponding parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. 
- You can view the full list of attributes in [`~generation.GenerationConfig`]. + For a complete overview of generate, check the [following + guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). @@ -1059,6 +1060,71 @@ def generate( - [`~generation.SampleEncoderDecoderOutput`], - [`~generation.BeamSearchEncoderDecoderOutput`], - [`~generation.BeamSampleEncoderDecoderOutput`] + + Examples: + + Greedy decoding, using the default generation configuration and ad hoc modifications: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # Generate up to 30 tokens + >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] + ``` + + Multinomial sampling, modifying an existing generation configuration: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # Sample up to 30 tokens + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> generation_config = GenerationConfig.from_pretrained("gpt2") + >>> generation_config.max_length = 30 + >>> generation_config.do_sample = True + >>> outputs = model.generate(input_ids, generation_config=generation_config) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] + ``` + + Beam-search decoding, using a freshly initialized generation configuration: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig + + >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") + + >>> sentence = "Paris is one of the densest populated areas in Europe." + >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids + + >>> generation_config = GenerationConfig( + ... max_length=64, + ... num_beams=5, + ... bos_token_id=0, + ... eos_token_id=0, + ... decoder_start_token_id=58100, + ... pad_token_id=58100, + ... bad_words_ids=[[58100]], + ... ) + >>> outputs = model.generate(input_ids, generation_config=generation_config) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" # 1. Handle `generation_config` and kwargs that might update it if generation_config is None: diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 5cfe0995655f..004720e110b9 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import tempfile import unittest from parameterized import parameterized -from transformers.generation import GenerationConfig +from transformers import AutoConfig, GenerationConfig class LogitsProcessorTest(unittest.TestCase): @@ -43,3 +44,33 @@ def test_save_load_config(self, config_name): self.assertEqual(loaded_config.top_k, 50) self.assertEqual(loaded_config.max_length, 20) self.assertEqual(loaded_config.max_time, None) + + def test_from_model_config(self): + model_config = AutoConfig.from_pretrained("gpt2") + generation_config_from_model = GenerationConfig.from_model_config(model_config) + default_generation_config = GenerationConfig() + + # The generation config has loaded a few non-default parameters from the model config + self.assertNotEqual(generation_config_from_model, default_generation_config) + + # One of those parameters is eos_token_id -- check if it matches + self.assertNotEqual(generation_config_from_model.eos_token_id, default_generation_config.eos_token_id) + self.assertEqual(generation_config_from_model.eos_token_id, model_config.eos_token_id) + + def test_update(self): + generation_config = GenerationConfig() + update_kwargs = { + "max_new_tokens": 1024, + "foo": "bar", + } + update_kwargs_copy = copy.deepcopy(update_kwargs) + unused_kwargs = generation_config.update(**update_kwargs) + + # update_kwargs was not modified (no side effects) + self.assertEqual(update_kwargs, update_kwargs_copy) + + # update_kwargs was used to update the config on valid attributes + self.assertEqual(generation_config.max_new_tokens, 1024) + + # `.update()` returns a dictionary of unused kwargs + self.assertEqual(unused_kwargs, {"foo": "bar"}) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index abc591b74e17..9293dc3c3934 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -3,7 +3,6 @@ docs/source/es/quicktour.mdx docs/source/en/pipeline_tutorial.mdx docs/source/en/autoclass_tutorial.mdx docs/source/en/task_summary.mdx -docs/source/en/main_classes/text_generation.mdx docs/source/en/model_doc/markuplm.mdx docs/source/en/model_doc/speech_to_text.mdx docs/source/en/model_doc/switch_transformers.mdx From 1b182eec238d661aea6332d714b2bd2593cd3827 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 23 Nov 2022 17:19:43 +0000 Subject: [PATCH 06/20] adapt rag to the new structure --- .../generation/configuration_utils.py | 11 +- src/transformers/models/rag/modeling_rag.py | 196 +++++------------- 2 files changed, 62 insertions(+), 145 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index d086587d6b71..7a58f1f60d8e 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -604,9 +604,16 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" config_dict = model_config.to_dict() config = cls.from_dict(config_dict, return_unused_kwargs=False) - # Handles a few special cases - if config.eos_token_id is None and hasattr(config_dict, "decoder"): + # Special cases: + # 1. eos_token_id defined in the decoder + if config.eos_token_id is None and "decoder" in config_dict: config.eos_token_id = config_dict["decoder"]["eos_token_id"] + # 2. 
RAG + if "generator" in config_dict: + config.bos_token_id = config_dict["generator"]["bos_token_id"] + config.eos_token_id = config_dict["generator"]["eos_token_id"] + config.pad_token_id = config_dict["generator"]["pad_token_id"] + config.decoder_start_token_id = config_dict["generator"]["decoder_start_token_id"] return config diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index c4b102a204f6..461e06ec4f75 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -14,6 +14,7 @@ # limitations under the License. """RAG model implementation.""" +import copy from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union @@ -21,7 +22,7 @@ from torch import nn from ...configuration_utils import PretrainedConfig -from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList +from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -1384,33 +1385,12 @@ def generate( context_input_ids: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None, doc_scores: Optional[torch.FloatTensor] = None, - max_length: Optional[int] = None, - min_length: Optional[int] = None, - early_stopping: Optional[bool] = None, - use_cache: Optional[bool] = None, - num_beams: Optional[int] = None, - num_beam_groups: Optional[int] = None, - diversity_penalty: Optional[float] = None, - bos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - no_repeat_ngram_size: Optional[int] = None, - encoder_no_repeat_ngram_size: Optional[int] = None, - repetition_penalty: Optional[float] = None, - bad_words_ids: Optional[List[List[int]]] = None, - num_return_sequences: Optional[int] = None, - decoder_start_token_id: Optional[int] = None, n_docs: Optional[int] = None, + generation_config: Optional[GenerationConfig] = None, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None, logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), - renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), - forced_bos_token_id: Optional[int] = None, - forced_eos_token_id: Optional[int] = None, - remove_invalid_values: Optional[bool] = None, - exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, - **model_kwargs + **kwargs ) -> torch.LongTensor: """ Implements RAG token decoding. @@ -1444,51 +1424,15 @@ def generate( If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. - min_length (`int`, *optional*, defaults to 10): - The minimum length of the sequence to be generated. - early_stopping (`bool`, *optional*, defaults to `False`): - Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or - not. 
- use_cache: (`bool`, *optional*, defaults to `True`): - Whether or not the model should use the past last key/values attentions (if applicable to the model) to - speed up decoding. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - bos_token_id (`int`, *optional*): - The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - length_penalty (`float`, *optional*, defaults to 1.0): - Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent - to the sequence length, which in turn is used to divide the score of the sequence. Since the score is - the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, - while `length_penalty` < 0.0 encourages shorter sequences. - no_repeat_ngram_size (`int`, *optional*, defaults to 0): - If set to int > 0, all ngrams of that size can only occur once. - encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): - If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the - `decoder_input_ids`. - bad_words_ids(`List[int]`, *optional*): - List of token ids that are not allowed to be generated. In order to get the tokens of the words that - should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. - num_beams (`int`, *optional*, defaults to 1): - Number of beams for beam search. 1 means no beam search. - num_beam_groups (`int`, *optional*, defaults to 1): - Number of groups to divide `num_beams` into in order to ensure diversity among different groups of - beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. - diversity_penalty (`float`, *optional*, defaults to 0.0): - This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is - enabled. - num_return_sequences(`int`, *optional*, defaults to 1): - The number of independently computed returned sequences for each element in the batch. Note that this - is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where - we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an - encoder-decoder model starts decoding with a different token than *bos*, the id of that token. n_docs (`int`, *optional*, defaults to `config.n_docs`) Number of documents to retrieve and/or number of documents for which to generate an answer. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. 
This function takes 2 arguments `inputs_ids` and the batch ID @@ -1497,53 +1441,30 @@ def generate( constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and a - model's config. If a logit processor is passed that is already created with the arguments or a model's - config an error is thrown. + Custom logits processors that complement the default logits processors built from arguments and a + model's config. If a logit processor is passed that is already created with the arguments or a model's + config an error is thrown. stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - model's config. If a stopping criteria is passed that is already created with the arguments or a - model's config an error is thrown. - forced_bos_token_id (`int`, *optional*): - The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful - for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be - the target language token. - forced_eos_token_id (`int`, *optional*): - The id of the token to force as the last generated token when `max_length` is reached. - remove_invalid_values (`bool`, *optional*): - Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to - crash. Note that using `remove_invalid_values` can slow down generation. + Custom stopping criteria that complement the default stopping criteria built from arguments and a + model's config. If a stopping criteria is passed that is already created with the arguments or a + model's config an error is thrown. + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. Return: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. 
""" + # Handle `generation_config` and kwargs that might update it + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + # set default parameters n_docs = n_docs if n_docs is not None else self.config.n_docs - num_beams = num_beams if num_beams is not None else self.config.num_beams - num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups - max_length = max_length if max_length is not None else self.config.max_length - num_return_sequences = ( - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id - use_cache = use_cache if use_cache is not None else self.config.use_cache - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.config.generator.decoder_start_token_id - ) - remove_invalid_values = ( - remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values - ) - exponential_decay_length_penalty = ( - exponential_decay_length_penalty - if exponential_decay_length_penalty is not None - else self.config.exponential_decay_length_penalty - ) # retrieve docs if self.retriever is not None and context_input_ids is None: @@ -1583,8 +1504,8 @@ def generate( encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True) input_ids = torch.full( - (batch_size * num_beams, 1), - decoder_start_token_id, + (batch_size * generation_config.num_beams, 1), + generation_config.decoder_start_token_id, dtype=torch.long, device=next(self.parameters()).device, ) @@ -1600,10 +1521,12 @@ def extend_enc_output(tensor, num_beams=None): return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:]) # correctly extend last_hidden_state and attention mask - context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams) - encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams) + context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams) + encoder_outputs["last_hidden_state"] = extend_enc_output( + last_hidden_state, num_beams=generation_config.num_beams + ) - doc_scores = doc_scores.repeat_interleave(num_beams, dim=0) + doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0) # define start_len & additional parameters model_kwargs["doc_scores"] = doc_scores @@ -1612,64 +1535,51 @@ def extend_enc_output(tensor, num_beams=None): model_kwargs["n_docs"] = n_docs pre_processor = self._get_logits_processor( - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=context_input_ids, - bad_words_ids=bad_words_ids, - min_length=min_length, - max_length=max_length, - eos_token_id=eos_token_id, - forced_bos_token_id=forced_bos_token_id, - forced_eos_token_id=forced_eos_token_id, 
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - num_beams=num_beams, - num_beam_groups=num_beam_groups, - diversity_penalty=diversity_penalty, - remove_invalid_values=remove_invalid_values, - exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, - renormalize_logits=renormalize_logits, ) - if num_beams == 1: - if num_return_sequences > 1: + if generation_config.num_beams == 1: + if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." ) return self.greedy_search( input_ids, logits_processor=pre_processor, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, **model_kwargs, ) - elif num_beams > 1: - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - if num_return_sequences > num_beams: + elif generation_config.num_beams > 1: + if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") beam_scorer = BeamSearchScorer( batch_size=batch_size, - num_beams=num_beams, + num_beams=generation_config.num_beams, device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, ) return self.beam_search( input_ids, beam_scorer, logits_processor=pre_processor, - max_length=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, + max_length=generation_config.max_length, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, **model_kwargs, ) else: - raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}") + raise ValueError( + f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}" + ) def get_input_embeddings(self): return self.rag.generator.get_input_embeddings() From 444783b76dbee6ccb8cf853ac6fb3a643622c3b0 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 23 Nov 2022 20:02:46 +0000 Subject: [PATCH 07/20] handle models not instantiated with from_pretained (like in tests) --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 201f170f6735..95309cad1fda 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1128,7 +1128,7 @@ def generate( ```""" # 1. 
Handle `generation_config` and kwargs that might update it if generation_config is None: - generation_config = self.generation_config + generation_config = self.generation_config if self.generation_config is not None else GenerationConfig() generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs From d54cd5cd4c65af20fafc93aa6cdb5e4056eb4d95 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 24 Nov 2022 12:13:32 +0000 Subject: [PATCH 08/20] better default generation config --- src/transformers/generation/utils.py | 2 +- src/transformers/modeling_utils.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 95309cad1fda..201f170f6735 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1128,7 +1128,7 @@ def generate( ```""" # 1. Handle `generation_config` and kwargs that might update it if generation_config is None: - generation_config = self.generation_config if self.generation_config is not None else GenerationConfig() + generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 84e1e2e0a45b..ac7163fed1a1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1023,8 +1023,9 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): # Save config and origin of the pretrained weights if given in model self.config = config self.name_or_path = config.name_or_path - self.generation_config = None # May be overwritten during `from_pretrained()` self.warnings_issued = {} + # May be overwritten during `from_pretrained()`, if there is a `generation_config.json` in the model folder + self.generation_config = GenerationConfig.from_model_config(config) def post_init(self): """ @@ -2478,8 +2479,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.eval() - # If it is a model with generation capabilities, load a default generation config (from the default path if - # it exists, otherwise from the model config) + # If it is a model with generation capabilities, attempt to set the generation config to an existing + # `generation_config.json` file. Otherwise, keep the generation config created from the model config. 
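        # (Concretely: a checkpoint that ships a `generation_config.json` overrides the defaults that `__init__`
        # built via `GenerationConfig.from_model_config(config)`; a checkpoint without one keeps that object as-is.)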
if hasattr(cls, "prepare_inputs_for_generation"): try: generation_config = GenerationConfig.from_pretrained( @@ -2496,9 +2497,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P _from_pipeline=from_pipeline, **kwargs, ) + model.generation_config = generation_config except EnvironmentError: - generation_config = GenerationConfig.from_model_config(config) - model.generation_config = generation_config + pass # Dispatch model with hooks on all devices if necessary if device_map is not None: From 519ef61734721fc3cd2b8e38bcb16e6661a56fb4 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 24 Nov 2022 17:06:28 +0000 Subject: [PATCH 09/20] add can_generate fn --- src/transformers/generation/utils.py | 61 ++++++++++++++-------------- src/transformers/modeling_utils.py | 19 +++++++-- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 201f170f6735..8fb26fafaadd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -936,7 +936,7 @@ def _validate_model_class(self): Confirms that the model class is compatible with generation. If not, raises an exception that points to the right class to use. """ - if not hasattr(self, "prepare_inputs_for_generation"): + if not self.can_generate(): generate_compatible_mappings = [ MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, @@ -1126,17 +1126,16 @@ def generate( >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" - # 1. Handle `generation_config` and kwargs that might update it + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + if generation_config is None: generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - - # 2. Validate the `.generate()` call - self._validate_model_class() self._validate_model_kwargs(model_kwargs.copy()) - # 3. Set generation parameters if not already defined + # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() @@ -1151,7 +1150,7 @@ def generate( ) generation_config.pad_token_id = generation_config.eos_token_id - # 4. Define model inputs + # 3. Define model inputs # inputs_tensor has to be defined # model_input_name is defined if model-specific keyword input is passed # otherwise model_input_name is None @@ -1161,7 +1160,7 @@ def generate( ) batch_size = inputs_tensor.shape[0] - # 5. Define other model kwargs + # 4. Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states model_kwargs["use_cache"] = generation_config.use_cache @@ -1192,7 +1191,7 @@ def generate( inputs_tensor, model_kwargs, model_input_name ) - # 6. Prepare `input_ids` which will be used for auto-regressive generation + # 5. 
Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: input_ids = self._prepare_decoder_input_ids_for_generation( batch_size, @@ -1205,7 +1204,7 @@ def generate( # if decoder-only then inputs_tensor has to be `input_ids` input_ids = inputs_tensor - # 7. Prepare `max_length` depending on other stopping criteria. + # 6. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length == 20 if has_default_max_length and generation_config.max_new_tokens is None: @@ -1239,7 +1238,7 @@ def generate( " increasing `max_new_tokens`." ) - # 8. determine generation mode + # 7. determine generation mode is_constraint_gen_mode = ( generation_config.constraints is not None or generation_config.force_words_ids is not None ) @@ -1305,7 +1304,7 @@ def generate( UserWarning, ) - # 9. prepare distribution pre_processing samplers + # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, @@ -1314,11 +1313,11 @@ def generate( logits_processor=logits_processor, ) - # 10. prepare stopping criteria + # 9. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) - # 11. go into different generation modes + # 10. go into different generation modes if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( @@ -1326,7 +1325,7 @@ def generate( " greedy search." ) - # 12. run greedy search + # 11. run greedy search return self.greedy_search( input_ids, logits_processor=logits_processor, @@ -1362,10 +1361,10 @@ def generate( ) elif is_sample_gen_mode: - # 12. prepare logits warper + # 11. prepare logits warper logits_warper = self._get_logits_warper(generation_config) - # 13. expand input_ids with `num_return_sequences` additional sequences per batch + # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, @@ -1373,7 +1372,7 @@ def generate( **model_kwargs, ) - # 14. run sample + # 13. run sample return self.sample( input_ids, logits_processor=logits_processor, @@ -1394,7 +1393,7 @@ def generate( if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - # 12. prepare beam search scorer + # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, @@ -1403,14 +1402,14 @@ def generate( do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, ) - # 13. interleave input_ids with `num_beams` additional sequences per batch + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 14. run beam search + # 13. run beam search return self.beam_search( input_ids, beam_scorer, @@ -1425,12 +1424,12 @@ def generate( ) elif is_beam_sample_gen_mode: - # 12. prepare logits warper + # 11. 
prepare logits warper logits_warper = self._get_logits_warper(generation_config) if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") - # 13. prepare beam search scorer + # 12. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size * generation_config.num_return_sequences, num_beams=generation_config.num_beams, @@ -1439,7 +1438,7 @@ def generate( do_early_stopping=generation_config.early_stopping, ) - # 14. interleave input_ids with `num_beams` additional sequences per batch + # 13. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams * generation_config.num_return_sequences, @@ -1447,7 +1446,7 @@ def generate( **model_kwargs, ) - # 15. run beam sample + # 14. run beam sample return self.beam_sample( input_ids, beam_scorer, @@ -1476,7 +1475,7 @@ def generate( if not has_default_typical_p: raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") - # 12. prepare beam search scorer + # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, @@ -1487,14 +1486,14 @@ def generate( num_beam_hyps_to_keep=generation_config.num_return_sequences, num_beam_groups=generation_config.num_beam_groups, ) - # 13. interleave input_ids with `num_beams` additional sequences per batch + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 14. run beam search + # 13. run beam search return self.group_beam_search( input_ids, beam_scorer, @@ -1564,7 +1563,7 @@ def typeerror(): constraint = PhrasalConstraint(word_ids) final_constraints.append(constraint) - # 12. prepare beam search scorer + # 11. prepare beam search scorer constrained_beam_scorer = ConstrainedBeamSearchScorer( constraints=final_constraints, batch_size=batch_size, @@ -1574,14 +1573,14 @@ def typeerror(): do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, ) - # 13. interleave input_ids with `num_beams` additional sequences per batch + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - # 14. run beam search + # 13. 
run beam search return self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ac7163fed1a1..baa3f66bf040 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1024,8 +1024,12 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): self.config = config self.name_or_path = config.name_or_path self.warnings_issued = {} - # May be overwritten during `from_pretrained()`, if there is a `generation_config.json` in the model folder - self.generation_config = GenerationConfig.from_model_config(config) + + if self.can_generate(): + # May be overwritten during `from_pretrained()`, if there is a `generation_config.json` in the model folder + self.generation_config = GenerationConfig.from_model_config(config) + else: + self.generation_config = None def post_init(self): """ @@ -1108,6 +1112,15 @@ def base_model(self) -> nn.Module: """ return getattr(self, self.base_model_prefix, self) + def can_generate(self) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + :bool: Whether this model can generate sequences with `.generate()`. + """ + return hasattr(self, "prepare_inputs_for_generation") + def get_input_embeddings(self) -> nn.Module: """ Returns the model's input embeddings. @@ -2481,7 +2494,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # If it is a model with generation capabilities, attempt to set the generation config to an existing # `generation_config.json` file. Otherwise, keep the generation config created from the model config. - if hasattr(cls, "prepare_inputs_for_generation"): + if model.can_generate(): try: generation_config = GenerationConfig.from_pretrained( pretrained_model_name_or_path, From e628471cf39dedb27ea5dd36a3a2ad9bd956ad32 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 24 Nov 2022 19:13:40 +0000 Subject: [PATCH 10/20] handle legacy use case of ad hoc model config changes --- .../generation/configuration_utils.py | 17 +-- src/transformers/generation/utils.py | 136 ++++++++++-------- src/transformers/modeling_utils.py | 9 +- 3 files changed, 88 insertions(+), 74 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 7a58f1f60d8e..87f2545bbe83 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -604,16 +604,13 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" config_dict = model_config.to_dict() config = cls.from_dict(config_dict, return_unused_kwargs=False) - # Special cases: - # 1. eos_token_id defined in the decoder - if config.eos_token_id is None and "decoder" in config_dict: - config.eos_token_id = config_dict["decoder"]["eos_token_id"] - # 2. 
RAG - if "generator" in config_dict: - config.bos_token_id = config_dict["generator"]["bos_token_id"] - config.eos_token_id = config_dict["generator"]["eos_token_id"] - config.pad_token_id = config_dict["generator"]["pad_token_id"] - config.decoder_start_token_id = config_dict["generator"]["decoder_start_token_id"] + # Special cases: tokens defined in nested parts of the config + for component_name in ("decoder", "generator"): + if component_name in config_dict: + config.bos_token_id = config_dict[component_name]["bos_token_id"] + config.eos_token_id = config_dict[component_name]["eos_token_id"] + config.pad_token_id = config_dict[component_name]["pad_token_id"] + config.decoder_start_token_id = config_dict[component_name]["decoder_start_token_id"] return config diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8fb26fafaadd..c44fb6c64106 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -622,26 +622,16 @@ def _prepare_decoder_input_ids_for_generation( def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id ) - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "decoder_start_token_id") - and self.config.decoder.decoder_start_token_id is not None - ): - return self.config.decoder.decoder_start_token_id elif bos_token_id is not None: return bos_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "bos_token_id") - and self.config.decoder.bos_token_id is not None - ): - return self.config.decoder.bos_token_id raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) @@ -1129,8 +1119,12 @@ def generate( # 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() + # priority: `generation_config` argument > `model.generation_config` > generation config from `model.config` if generation_config is None: - generation_config = self.generation_config + if self.generation_config is not None: + generation_config = self.generation_config + else: + generation_config = GenerationConfig.from_model_config(self.config) generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs self._validate_model_kwargs(model_kwargs.copy()) @@ -1686,15 +1680,19 @@ def contrastive_search( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2041,15 +2039,19 @@ def greedy_search( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - 
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2289,15 +2291,19 @@ def sample( ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2541,15 +2547,19 @@ def beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -2861,15 +2871,19 @@ def beam_sample( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - 
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3170,15 +3184,19 @@ def group_beam_search( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3539,15 +3557,19 @@ def constrained_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - output_scores = output_scores if output_scores is not None else self.config.output_scores - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + 
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( - return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index baa3f66bf040..04dbb47351a2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1024,12 +1024,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): self.config = config self.name_or_path = config.name_or_path self.warnings_issued = {} - - if self.can_generate(): - # May be overwritten during `from_pretrained()`, if there is a `generation_config.json` in the model folder - self.generation_config = GenerationConfig.from_model_config(config) - else: - self.generation_config = None + self.generation_config = None def post_init(self): """ @@ -2493,7 +2488,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model.eval() # If it is a model with generation capabilities, attempt to set the generation config to an existing - # `generation_config.json` file. Otherwise, keep the generation config created from the model config. + # `generation_config.json` file. if model.can_generate(): try: generation_config = GenerationConfig.from_pretrained( From e1acd08cbbd4394b22ad951c57d13cd7f2f6a67d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 24 Nov 2022 19:42:09 +0000 Subject: [PATCH 11/20] initialize gen config from config in individual methods, if gen config is none --- src/transformers/generation/utils.py | 133 ++++++++++++++++----------- 1 file changed, 77 insertions(+), 56 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c44fb6c64106..d694354d4471 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1677,22 +1677,25 @@ def contrastive_search( ['DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). 
DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it'] ```""" # init values + generation_config = ( + self.generation_config + if self.generation_config is not None + else GenerationConfig.from_model_config(self.config) + ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) + pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else generation_config.output_scores + output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2030,6 +2033,11 @@ def greedy_search( ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values + generation_config = ( + self.generation_config + if self.generation_config is not None + else GenerationConfig.from_model_config(self.config) + ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -2039,19 +2047,17 @@ def greedy_search( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) + pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else generation_config.output_scores + output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else 
self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2281,6 +2287,11 @@ def sample( ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] ```""" # init values + generation_config = ( + self.generation_config + if self.generation_config is not None + else GenerationConfig.from_model_config(self.config) + ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -2291,19 +2302,17 @@ def sample( ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) + pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else generation_config.output_scores + output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2536,6 +2545,11 @@ def beam_search( ['Wie alt bist du?'] ```""" # init values + generation_config = ( + self.generation_config + if self.generation_config is not None + else GenerationConfig.from_model_config(self.config) + ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -2547,19 +2561,17 @@ def beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) + 
pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else generation_config.output_scores + output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -2862,6 +2874,11 @@ def beam_sample( ['Wie alt bist du?'] ```""" # init values + generation_config = ( + self.generation_config + if self.generation_config is not None + else GenerationConfig.from_model_config(self.config) + ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -2871,19 +2888,17 @@ def beam_sample( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) + pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else generation_config.output_scores + output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3175,6 +3190,11 @@ def group_beam_search( ['Wie alt bist du?'] ```""" # init values + generation_config = ( + self.generation_config + if self.generation_config is not None + else GenerationConfig.from_model_config(self.config) + ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -3184,19 +3204,17 @@ def group_beam_search( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else 
self.generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) + pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else generation_config.output_scores + output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3546,6 +3564,11 @@ def constrained_beam_search( ['Wie alt sind Sie?'] ```""" # init values + generation_config = ( + self.generation_config + if self.generation_config is not None + else GenerationConfig.from_model_config(self.config) + ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -3557,19 +3580,17 @@ def constrained_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) + pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else generation_config.output_scores + output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate + else generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples From df9618acda23ddab63993e28ec65b3ea34bfb0b1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 25 Nov 2022 10:58:39 +0000 Subject: [PATCH 12/20] fix _get_decoder_start_token_id when called outside GenerationMixin --- src/transformers/generation/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git 
a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d694354d4471..a76bddd076bb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -621,12 +621,15 @@ def _prepare_decoder_input_ids_for_generation( return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + generation_config = ( + self.generation_config + if self.generation_config is not None + else GenerationConfig.from_model_config(self.config) + ) decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.generation_config.decoder_start_token_id + decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id ) - bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id + bos_token_id = bos_token_id if bos_token_id is not None else generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id From 3c9072554fb6ce9fe59e53389d0bc98bcae67f9d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 25 Nov 2022 11:45:16 +0000 Subject: [PATCH 13/20] correct model config load order (set attr > model config > decoder config) --- .../generation/configuration_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 87f2545bbe83..774136d1775d 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -604,13 +604,14 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" config_dict = model_config.to_dict() config = cls.from_dict(config_dict, return_unused_kwargs=False) - # Special cases: tokens defined in nested parts of the config - for component_name in ("decoder", "generator"): - if component_name in config_dict: - config.bos_token_id = config_dict[component_name]["bos_token_id"] - config.eos_token_id = config_dict[component_name]["eos_token_id"] - config.pad_token_id = config_dict[component_name]["pad_token_id"] - config.decoder_start_token_id = config_dict[component_name]["decoder_start_token_id"] + # Special case: some models have generation attributes set in the decoder. Use them if the attribute is unset. 
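        # (For example, a composite config such as RAG's keeps `eos_token_id` and related tokens under
        # `config.generator`, and encoder-decoder configs may keep them under `config.decoder`; those nested
        # values are copied up below only while the top-level attribute still equals the fresh
        # `GenerationConfig()` default.)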
+ for decoder_name in ("decoder", "generator"): + if decoder_name in config_dict: + default_generation_config = GenerationConfig() + decoder_config = config_dict[decoder_name] + for attr in config.to_dict().keys(): + if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): + setattr(config, attr, decoder_config[attr]) return config From 26c1dc004d36c0d845a6adeea2f18a60911d693e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 25 Nov 2022 12:19:55 +0000 Subject: [PATCH 14/20] update rag to match latest changes --- src/transformers/generation/configuration_utils.py | 3 ++- src/transformers/models/rag/modeling_rag.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 774136d1775d..d37b1ab13457 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -604,7 +604,8 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" config_dict = model_config.to_dict() config = cls.from_dict(config_dict, return_unused_kwargs=False) - # Special case: some models have generation attributes set in the decoder. Use them if the attribute is unset. + # Special case: some models have generation attributes set in the decoder. Use them if still unset in the + # generation config. for decoder_name in ("decoder", "generator"): if decoder_name in config_dict: default_generation_config = GenerationConfig() diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 461e06ec4f75..4ad37d8ee115 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1459,7 +1459,10 @@ def generate( """ # Handle `generation_config` and kwargs that might update it if generation_config is None: - generation_config = self.generation_config + if self.generation_config is not None: + generation_config = self.generation_config + else: + generation_config = GenerationConfig.from_model_config(self.config) generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs From 09f7ad307f9d58641894cd193c19999c71654e7d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 13 Dec 2022 16:58:27 +0000 Subject: [PATCH 15/20] Apply suggestions from code review Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/main_classes/text_generation.mdx | 4 ++-- src/transformers/modeling_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index 6226b3c0e05c..a538f26b13a7 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -18,7 +18,7 @@ Each framework has a generate method for auto-regressive text generation impleme - TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`]. - Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`]. 
-Regardless of your framework of choice, you can parametrize the generate method with a [`~generation.GenerationConfig`] +Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`] class instance. Please refer to this class for the complete list of generation parameters, which control the behavior of the generation method. @@ -45,7 +45,7 @@ generation_config.save_pretrained("my_account/my_model", push_to_hub=True) If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that -default values are omitted. Some attributes, like `max_new_tokens`, have a conservative default value, to avoid running +default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running into resource limitations. Make sure you double-check the defaults in the documentation. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 04dbb47351a2..c819c6d94c43 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1112,7 +1112,7 @@ def can_generate(self) -> bool: Returns whether this model can generate sequences with `.generate()`. Returns: - :bool: Whether this model can generate sequences with `.generate()`. + `bool`: Whether this model can generate sequences with `.generate()`. """ return hasattr(self, "prepare_inputs_for_generation") From 5ff10a82040ac795d5cf9715bf8da7823a55500a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 Dec 2022 13:27:17 +0000 Subject: [PATCH 16/20] load gen config from model config in model.from_pretrained --- .../en/main_classes/text_generation.mdx | 3 + src/transformers/generation/utils.py | 159 ++++++++---------- src/transformers/modeling_utils.py | 16 +- src/transformers/models/rag/modeling_rag.py | 5 +- 4 files changed, 82 insertions(+), 101 deletions(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index a538f26b13a7..ce74963f5503 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -68,6 +68,9 @@ my_awesome_config.save_pretrained( generation_config = GenerationConfig.from_pretrained( "my_account/my_model", config_file_name="awesome_generation_config.json" ) + +# Generate with the restored generation configuration +model.generate(inputs, generation_config=generation_config) ``` Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a76bddd076bb..bec214db6144 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -480,6 +480,11 @@ class GenerationMixin: `constraints!=None` or `force_words_ids!=None`. """ + def prepare_inputs_for_generation(self, *args, **kwargs): + raise NotImplementedError( + "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." 
+ ) + def _prepare_model_inputs( self, inputs: Optional[torch.Tensor] = None, @@ -621,15 +626,12 @@ def _prepare_decoder_input_ids_for_generation( return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: - generation_config = ( - self.generation_config - if self.generation_config is not None - else GenerationConfig.from_model_config(self.config) - ) decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id ) - bos_token_id = bos_token_id if bos_token_id is not None else generation_config.bos_token_id + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id @@ -1008,7 +1010,7 @@ def generate( generation_config (`~generation.GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which has the following loading + `generation_config` is not provided, the default will be used, which had the following loading priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. @@ -1122,12 +1124,8 @@ def generate( # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() - # priority: `generation_config` argument > `model.generation_config` > generation config from `model.config` - if generation_config is None: - if self.generation_config is not None: - generation_config = self.generation_config - else: - generation_config = GenerationConfig.from_model_config(self.config) + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + generation_config = generation_config if generation_config is not None else self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs self._validate_model_kwargs(model_kwargs.copy()) @@ -1680,25 +1678,22 @@ def contrastive_search( ['DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). 
DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it'] ```""" # init values - generation_config = ( - self.generation_config - if self.generation_config is not None - else GenerationConfig.from_model_config(self.config) - ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else generation_config.output_scores - output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else generation_config.return_dict_in_generate + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2036,11 +2031,6 @@ def greedy_search( ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values - generation_config = ( - self.generation_config - if self.generation_config is not None - else GenerationConfig.from_model_config(self.config) - ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -2050,17 +2040,19 @@ def greedy_search( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else generation_config.output_scores - output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else 
generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else generation_config.return_dict_in_generate + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2290,11 +2282,6 @@ def sample( ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] ```""" # init values - generation_config = ( - self.generation_config - if self.generation_config is not None - else GenerationConfig.from_model_config(self.config) - ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -2305,17 +2292,19 @@ def sample( ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else generation_config.output_scores - output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else generation_config.return_dict_in_generate + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples @@ -2548,11 +2537,6 @@ def beam_search( ['Wie alt bist du?'] ```""" # init values - generation_config = ( - self.generation_config - if self.generation_config is not None - else GenerationConfig.from_model_config(self.config) - ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -2564,17 +2548,19 @@ def beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else generation_config.output_scores - output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions + pad_token_id = pad_token_id 
if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else generation_config.return_dict_in_generate + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -2877,11 +2863,6 @@ def beam_sample( ['Wie alt bist du?'] ```""" # init values - generation_config = ( - self.generation_config - if self.generation_config is not None - else GenerationConfig.from_model_config(self.config) - ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -2891,17 +2872,19 @@ def beam_sample( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else generation_config.output_scores - output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else generation_config.return_dict_in_generate + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3193,11 +3176,6 @@ def group_beam_search( ['Wie alt bist du?'] ```""" # init values - generation_config = ( - self.generation_config - if self.generation_config is not None - else GenerationConfig.from_model_config(self.config) - ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -3207,17 +3185,19 @@ def group_beam_search( UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id - 
output_scores = output_scores if output_scores is not None else generation_config.output_scores - output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else generation_config.return_dict_in_generate + else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) @@ -3567,11 +3547,6 @@ def constrained_beam_search( ['Wie alt sind Sie?'] ```""" # init values - generation_config = ( - self.generation_config - if self.generation_config is not None - else GenerationConfig.from_model_config(self.config) - ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: @@ -3583,17 +3558,19 @@ def constrained_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id - output_scores = output_scores if output_scores is not None else generation_config.output_scores - output_attentions = output_attentions if output_attentions is not None else generation_config.output_attentions + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None - else generation_config.return_dict_in_generate + else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c819c6d94c43..f90e0bdc0c9b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1114,7 +1114,11 @@ def can_generate(self) -> bool: Returns: `bool`: Whether this model can generate sequences with `.generate()`. 
""" - return hasattr(self, "prepare_inputs_for_generation") + try: + self.prepare_inputs_for_generation() + except NotImplementedError: + return True + return False def get_input_embeddings(self) -> nn.Module: """ @@ -2487,8 +2491,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.eval() - # If it is a model with generation capabilities, attempt to set the generation config to an existing - # `generation_config.json` file. + # If it is a model with generation capabilities, set the generation config if model.can_generate(): try: generation_config = GenerationConfig.from_pretrained( @@ -2505,9 +2508,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P _from_pipeline=from_pipeline, **kwargs, ) - model.generation_config = generation_config - except EnvironmentError: - pass + except OSError: + logger.warning("Generation config not found, creating it from the model config") + generation_config = GenerationConfig.from_model_config(config) + model.generation_config = generation_config # Dispatch model with hooks on all devices if necessary if device_map is not None: diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 4ad37d8ee115..461e06ec4f75 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1459,10 +1459,7 @@ def generate( """ # Handle `generation_config` and kwargs that might update it if generation_config is None: - if self.generation_config is not None: - generation_config = self.generation_config - else: - generation_config = GenerationConfig.from_model_config(self.config) + generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs From 62cd2c6c88a7c66a045728f225cc97148a479c46 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 Dec 2022 14:57:03 +0000 Subject: [PATCH 17/20] fix can_generate fn --- src/transformers/modeling_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f90e0bdc0c9b..60ce16bd1cb2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1114,9 +1114,8 @@ def can_generate(self) -> bool: Returns: `bool`: Whether this model can generate sequences with `.generate()`. """ - try: - self.prepare_inputs_for_generation() - except NotImplementedError: + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation + if "GenerationMixin" in str(self.prepare_inputs_for_generation): return True return False From e7517ee137fd55205ed0ef628eb18f6e2b0db70c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 Dec 2022 15:39:52 +0000 Subject: [PATCH 18/20] handle generate calls without a previous from_pretrained (e.g. 
tests) --- src/transformers/modeling_utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 60ce16bd1cb2..4e08237da54a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1024,7 +1024,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): self.config = config self.name_or_path = config.name_or_path self.warnings_issued = {} - self.generation_config = None + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None def post_init(self): """ @@ -1116,8 +1116,8 @@ def can_generate(self) -> bool: """ # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation if "GenerationMixin" in str(self.prepare_inputs_for_generation): - return True - return False + return False + return True def get_input_embeddings(self) -> nn.Module: """ @@ -2490,10 +2490,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.eval() - # If it is a model with generation capabilities, set the generation config + # If it is a model with generation capabilities, attempt to load the generation config if model.can_generate(): try: - generation_config = GenerationConfig.from_pretrained( + model.generation_config = GenerationConfig.from_pretrained( pretrained_model_name_or_path, cache_dir=cache_dir, force_download=force_download, @@ -2508,9 +2508,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P **kwargs, ) except OSError: - logger.warning("Generation config not found, creating it from the model config") - generation_config = GenerationConfig.from_model_config(config) - model.generation_config = generation_config + logger.warning( + "Generation config file not found, using a generation config created from the model config." + ) + pass # Dispatch model with hooks on all devices if necessary if device_map is not None: From f58dc30fe2c41a064bf759fcfa45491b9d0e8d24 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 Dec 2022 17:18:04 +0000 Subject: [PATCH 19/20] add legacy behavior (and a warning) --- .../en/main_classes/text_generation.mdx | 31 ++++++++++++------- .../generation/configuration_utils.py | 13 ++++++-- src/transformers/generation/utils.py | 16 +++++++++- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index ce74963f5503..1d00406ac1e5 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -56,21 +56,28 @@ store several generation configurations for a single model (e.g. one for creativ other for summarization with beam search). 
 ```python
-from transformers import GenerationConfig
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
 
-# Create a new generation configuration for later use
-my_awesome_config = GenerationConfig(num_beams=8, do_sample=False, bad_words_ids=[[42, 43, 44]], eos_token_id=2)
-my_awesome_config.save_pretrained(
-    "my_account/my_model", config_file_name="awesome_generation_config.json", push_to_hub=True
-)
+tokenizer = AutoTokenizer.from_pretrained("t5-small")
+model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
 
-# Restore the generation configuration
-generation_config = GenerationConfig.from_pretrained(
-    "my_account/my_model", config_file_name="awesome_generation_config.json"
+translation_generation_config = GenerationConfig(
+    num_beams=4,
+    early_stopping=True,
+    decoder_start_token_id=0,
+    eos_token_id=model.config.eos_token_id,
+    pad_token_id=model.config.pad_token_id,
 )
-
-# Generate with the restored generation configuration
-model.generate(inputs, generation_config=generation_config)
+# If you were working on a model for which you had the right Hub permissions, you could store a named generation
+# config as follows
+translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)
+
+# You could then use the named generation config file to parameterize generation
+generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
+inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
+outputs = model.generate(**inputs, generation_config=generation_config)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+# ['Les fichiers de configuration sont faciles à utiliser !']
 ```
 
 Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index d37b1ab13457..a477ebe4203c 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -256,12 +256,20 @@ def __init__(self, **kwargs):
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
 
-        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the the hub interface.
+        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
+        # interface.
+ self._from_model_config = kwargs.pop("_from_model_config", False) self._commit_hash = kwargs.pop("_commit_hash", None) self.transformers_version = kwargs.pop("transformers_version", __version__) def __eq__(self, other): - return self.__dict__ == other.__dict__ + self_dict = self.__dict__.copy() + other_dict = other.__dict__.copy() + # ignore metadata + for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"): + self_dict.pop(metadata_field, None) + other_dict.pop(metadata_field, None) + return self_dict == other_dict def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" @@ -614,6 +622,7 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): setattr(config, attr, decoder_config[attr]) + config._from_model_config = True return config def update(self, **kwargs): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bec214db6144..03ad4a25a1d9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1125,7 +1125,21 @@ def generate( self._validate_model_class() # priority: `generation_config` argument > `model.generation_config` (the default generation config) - generation_config = generation_config if generation_config is not None else self.generation_config + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs self._validate_model_kwargs(model_kwargs.copy()) From f85e3b2f73ebeb36c02e83684f0465829db10edd Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 Dec 2022 19:29:19 +0000 Subject: [PATCH 20/20] lower logger severity --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4e08237da54a..6780f9b19f14 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2508,7 +2508,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P **kwargs, ) except OSError: - logger.warning( + logger.info( "Generation config file not found, using a generation config created from the model config." ) pass
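
---

Taken together, these patches make `model.generation_config` the single source of generation defaults: `from_pretrained` attaches a `GenerationConfig` loaded from `generation_config.json` when that file exists, and otherwise keeps the one built from the model config in `__init__`, while `generate()` and the individual decoding methods fall back to that attribute instead of re-deriving a config on every call. The snippet below is a minimal usage sketch of the resulting behavior, not part of the patch series itself; the `gpt2` checkpoint and the prompt text are illustrative assumptions.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Illustrative checkpoint (assumption): any model whose class overrides
# `prepare_inputs_for_generation` (i.e. `model.can_generate()` returns True) behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# After these patches, a default generation config is always attached to the model:
# either the checkpoint's `generation_config.json` or one derived from `model.config`.
print(model.generation_config)

inputs = tokenizer("Generation configs make defaults explicit", return_tensors="pt")

# Highest priority: an explicit `generation_config` argument overrides the model default.
beam_config = GenerationConfig(num_beams=4, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
beam_output = model.generate(**inputs, generation_config=beam_config)

# Ad hoc keyword arguments update a copy of the selected config for this call only.
sampled_output = model.generate(**inputs, do_sample=True, max_new_tokens=20)

# Legacy path: if the attached config was created from `model.config`
# (`_from_model_config` is True), mutating the model config still takes effect,
# but `generate()` now emits a deprecation warning (see the "add legacy behavior" patch).
model.config.num_beams = 2
legacy_output = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.batch_decode(legacy_output, skip_special_tokens=True))
```

Because `generate()` replaces `self.generation_config` with the freshly derived config after warning, the warning fires only when the two diverge again; old code that tweaks `model.config` keeps working while users are steered toward stand-alone generation configs.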