diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 9af0797e8897..cb82c438f799 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -15,6 +15,7 @@ # limitations under the License. +import copy import inspect import warnings from functools import partial @@ -33,6 +34,7 @@ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, ) from ..utils import ModelOutput, logging +from .configuration_utils import GenerationConfig from .flax_logits_process import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, @@ -136,6 +138,11 @@ class FlaxGenerationMixin: `do_sample=False`. """ + def prepare_inputs_for_generation(self, *args, **kwargs): + raise NotImplementedError( + "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." + ) + @staticmethod def _run_loop_in_debug(cond_fn, body_fn, init_state): """ @@ -171,7 +178,7 @@ def _validate_model_class(self): Confirms that the model class is compatible with generation. If not, raises an exception that points to the right class to use. """ - if not hasattr(self, "prepare_inputs_for_generation"): + if not self.can_generate(): generate_compatible_mappings = [ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, @@ -211,27 +218,11 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): def generate( self, input_ids: jnp.ndarray, - max_length: Optional[int] = None, - max_new_tokens: Optional[int] = None, - pad_token_id: Optional[int] = None, - bos_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - decoder_start_token_id: Optional[int] = None, - do_sample: Optional[bool] = None, + generation_config: Optional[GenerationConfig] = None, prng_key: Optional[jnp.ndarray] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - num_beams: Optional[int] = None, - no_repeat_ngram_size: Optional[int] = None, - min_length: Optional[int] = None, - forced_bos_token_id: Optional[int] = None, - forced_eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - early_stopping: Optional[bool] = None, trace: bool = True, params: Optional[Dict[str, jnp.ndarray]] = None, - **model_kwargs, + **kwargs, ): r""" Generates sequences of token ids for models with a language modeling head. The method supports the following @@ -246,100 +237,151 @@ def generate( - Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as - defined in the model's config (`config.json`) which in turn defaults to the - [`~modeling_utils.PretrainedConfig`] of the model. + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - + For a complete overview of generate, check the [following + guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). - Most of these parameters are explained in more detail in [this blog - post](https://huggingface.co/blog/how-to-generate). + Parameters: input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. - max_length (`int`, *optional*, defaults to `model.config.max_length`): - The maximum length the generated tokens can have. 
Corresponds to the length of the input prompt + - `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in - the prompt. - max_new_tokens (`int`, *optional*): - The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. - do_sample (`bool`, *optional*, defaults to `False`): - Whether or not to use sampling ; use greedy decoding otherwise. - temperature (`float`, *optional*, defaults to 1.0): - The value used to module the next token probabilities. - top_k (`int`, *optional*, defaults to 50): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p (`float`, *optional*, defaults to 1.0): - If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher - are kept for generation. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - bos_token_id (`int`, *optional*): - The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - num_beams (`int`, *optional*, defaults to 1): - Number of beams for beam search. 1 means no beam search. - decoder_start_token_id (`int`, *optional*): - If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. trace (`bool`, *optional*, defaults to `True`): Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a considerably slower runtime. params (`Dict[str, jnp.ndarray]`, *optional*): Optionally the model parameters can be passed. Can be useful for parallelized generation. - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model - is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs - should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part. + kwargs: + Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`].
Examples: + Greedy decoding, using the default generation configuration and ad hoc modifications: + ```python >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM - >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") - >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2") - >>> input_context = "The dog" - >>> # encode input context - >>> inputs = tokenizer(input_context, return_tensors="np") - >>> # generate candidates using sampling - >>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True) + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="np").input_ids + + >>> # Generate up to 30 tokens + >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) + >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True) + ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] + ``` + + Multinomial sampling, modifying an existing generation configuration: + + ```python + >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig + >>> import numpy as np + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="np").input_ids + + >>> # Sample up to 30 tokens + >>> generation_config = GenerationConfig.from_pretrained("gpt2") + >>> generation_config.max_length = 30 + >>> generation_config.do_sample = True + >>> outputs = model.generate( + ... input_ids, generation_config=generation_config, prng_key=np.asarray([0, 0], dtype=np.uint32) + ... ) >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True) + ['Today I believe we can finally get a change in that system. The way I saw it was this: a few years ago, this company would not'] + ``` + + Beam-search decoding, using a freshly initialized generation configuration: + + ```python + >>> from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM, GenerationConfig + + >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") + >>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") + + >>> sentence = "Paris is one of the densest populated areas in Europe." + >>> input_ids = tokenizer(sentence, return_tensors="np").input_ids + + >>> generation_config = GenerationConfig( + ... max_length=64, + ... num_beams=5, + ... bos_token_id=0, + ... eos_token_id=0, + ... decoder_start_token_id=58100, + ... pad_token_id=58100, + ... bad_words_ids=[[58100]], + ... 
) + >>> outputs = model.generate(input_ids, generation_config=generation_config) + >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True) + ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" - # Validate the `.generate()` call + # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs self._validate_model_kwargs(model_kwargs.copy()) # set init values - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id - ) prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) - if pad_token_id is None and eos_token_id is not None: + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask") is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." ) + eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, list): eos_token_id = eos_token_id[0] logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - pad_token_id = eos_token_id + generation_config.pad_token_id = eos_token_id - if decoder_start_token_id is None and self.config.is_encoder_decoder: + if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder: raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.") # decoder-only models should use left-padding for generation (can't be checked with `trace=True`) if not self.config.is_encoder_decoder and not trace: - if pad_token_id is not None and jnp.sum(input_ids[:, -1] == pad_token_id) > 0: + if ( + generation_config.pad_token_id is not None + and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0 + ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer." 
@@ -350,71 +392,62 @@ def generate( if model_kwargs.get("encoder_outputs") is None: model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs) # prepare decoder_input_ids for generation - input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id # Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] - if max_length is None and max_new_tokens is None: + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to " - f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is " - "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend " - "using `max_new_tokens` to control the maximum length of the generation.", + "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) - elif max_length is None and max_new_tokens is not None: - max_length = max_new_tokens + input_ids_seq_length - elif max_length is not None and max_new_tokens is not None: + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + elif not has_default_max_length and generation_config.max_new_tokens is not None: raise ValueError( "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" " limit to the generated output length. Remove one of those arguments. Please refer to the" " documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) - # default to config if still None - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length - if min_length is not None and min_length > max_length: + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( - f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " - f"length ({max_length})" + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" ) - if input_ids_seq_length >= max_length: + if input_ids_seq_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {max_length}. This can lead to unexpected behavior. You should consider increasing" - "`max_new_tokens`." + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`."
) - do_sample = do_sample if do_sample is not None else self.config.do_sample - num_beams = num_beams if num_beams is not None else self.config.num_beams + logits_processor = self._get_logits_processor(generation_config=generation_config) - if not do_sample and num_beams == 1: - logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id - ) + if not generation_config.do_sample and generation_config.num_beams == 1: return self._greedy_search( input_ids, - max_length, - pad_token_id, - eos_token_id, + generation_config.max_length, + generation_config.pad_token_id, + generation_config.eos_token_id, logits_processor=logits_processor, trace=trace, params=params, model_kwargs=model_kwargs, ) - elif do_sample and num_beams == 1: - logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) - logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id - ) + elif generation_config.do_sample and generation_config.num_beams == 1: + logits_warper = self._get_logits_warper(generation_config=generation_config) return self._sample( input_ids, - max_length, - pad_token_id, - eos_token_id, + generation_config.max_length, + generation_config.pad_token_id, + generation_config.eos_token_id, prng_key, logits_warper=logits_warper, logits_processor=logits_processor, @@ -422,31 +455,27 @@ def generate( params=params, model_kwargs=model_kwargs, ) - elif not do_sample and num_beams > 1: + elif not generation_config.do_sample and generation_config.num_beams > 1: # broadcast input_ids & encoder_outputs - input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams) + input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams) if "encoder_outputs" in model_kwargs: model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams( - model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams + model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams ) if "attention_mask" in model_kwargs: model_kwargs["attention_mask"] = self._expand_to_num_beams( - model_kwargs["attention_mask"], num_beams=num_beams + model_kwargs["attention_mask"], num_beams=generation_config.num_beams ) - logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id - ) - return self._beam_search( input_ids, - max_length, - pad_token_id, - eos_token_id, - length_penalty=length_penalty, - early_stopping=early_stopping, + generation_config.max_length, + generation_config.pad_token_id, + generation_config.eos_token_id, + length_penalty=generation_config.length_penalty, + early_stopping=generation_config.early_stopping, logits_processor=logits_processor, trace=trace, params=params, @@ -455,67 +484,44 @@ def generate( else: raise NotImplementedError("`Beam sampling is currently not implemented.") - def _get_logits_warper( - self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None - ) -> FlaxLogitsProcessorList: + def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList: """ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`] instances used for multinomial sampling. 
""" - - # init warp parameters - top_k = top_k if top_k is not None else self.config.top_k - top_p = top_p if top_p is not None else self.config.top_p - temperature = temperature if temperature is not None else self.config.temperature - # instantiate warpers list warpers = FlaxLogitsProcessorList() - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - if temperature is not None and temperature != 1.0: - warpers.append(FlaxTemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1)) - if top_p is not None and top_p < 1.0: - warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1)) + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1)) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1)) return warpers - def _get_logits_processor( - self, - no_repeat_ngram_size: int, - min_length: int, - max_length: int, - eos_token_id: int, - forced_bos_token_id: int, - forced_eos_token_id: int, - ) -> FlaxLogitsProcessorList: + def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList: """ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`] instances used to modify the scores of the language model head. 
""" processors = FlaxLogitsProcessorList() - # init warp parameters - no_repeat_ngram_size = ( - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size - ) - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - forced_bos_token_id = ( - forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id - ) - forced_eos_token_id = ( - forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id - ) + if ( + generation_config.min_length is not None + and generation_config.eos_token_id is not None + and generation_config.min_length > -1 + ): + processors.append( + FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id) + ) + if generation_config.forced_bos_token_id is not None: + processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) + if generation_config.forced_eos_token_id is not None: + processors.append( + FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) + ) - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - if min_length is not None and eos_token_id is not None and min_length > -1: - processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id)) - if forced_bos_token_id is not None: - processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id)) - if forced_eos_token_id is not None: - processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) return processors def _greedy_search( @@ -530,9 +536,9 @@ def _greedy_search( model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, ): # init values - max_length = max_length if max_length is not None else self.config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + max_length = max_length if max_length is not None else self.generation_config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id batch_size, cur_len = input_ids.shape @@ -618,9 +624,9 @@ def _sample( model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, ): # init values - max_length = max_length if max_length is not None else self.config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + max_length = max_length if max_length is not None else self.generation_config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) batch_size, cur_len = input_ids.shape @@ -716,7 +722,7 @@ def _beam_search( ): """ This beam search function is heavily inspired by Flax's official example: - https://github.com/google/flax/blob/master/examples/wmt/train.py#L254 + https://github.com/google/flax/blob/main/examples/wmt/decode.py """ def flatten_beam_dim(tensor): @@ -751,11 +757,11 @@ def gather_fn(tensor): return 
jax.tree_util.tree_map(gather_fn, nested) # init values - max_length = max_length if max_length is not None else self.config.max_length - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + max_length = max_length if max_length is not None else self.generation_config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping batch_size, num_beams, cur_len = input_ids.shape diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 69bd5a118b02..a643b43ab67f 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -33,7 +33,7 @@ from .configuration_utils import PretrainedConfig from .dynamic_module_utils import custom_object_save -from .generation import FlaxGenerationMixin +from .generation import FlaxGenerationMixin, GenerationConfig from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict from .utils import ( FLAX_WEIGHTS_INDEX_NAME, @@ -199,6 +199,7 @@ def __init__( self.key = PRNGKey(seed) self.dtype = dtype self.input_shape = input_shape + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None # To check if the model was intialized automatically. self._is_initialized = _do_init @@ -467,6 +468,16 @@ def load_flax_sharded_weights(cls, shard_files): # the state dict is unflattened to the match the format of model.params return unflatten_dict(state_sharded_dict, sep="/") + def can_generate(self) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation + if "GenerationMixin" in str(self.prepare_inputs_for_generation): + return False + return True + @classmethod def from_pretrained( cls, @@ -940,6 +951,29 @@ def from_pretrained( "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." ) + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." + ) + pass + if _do_init: # set correct parameters model.params = unflatten_dict(state)
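A minimal caller-side sketch (not part of the patch) of the flow the new `generate` signature enables. It assumes a public `gpt2` Flax checkpoint, as in the docstring examples, and relies on the `GenerationConfig.update()` contract visible at the call site above (attributes that match the config are consumed, everything else is returned and treated as model kwargs); no generated text is asserted:

```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Today I believe we can finally", return_tensors="np").input_ids

# 1) No explicit config: `generate` falls back to `model.generation_config`, which
#    `from_pretrained` fills from `generation_config.json` when available and otherwise
#    builds from the model config (the OSError fallback in the last hunk).
#    Kwargs matching `GenerationConfig` attributes override the (deep-copied) config.
outputs = model.generate(input_ids, max_new_tokens=10)

# 2) Explicit config: the passed object is deep-copied inside `generate`, so the
#    caller's instance is left untouched.
generation_config = GenerationConfig(max_length=30, do_sample=False)
outputs = model.generate(input_ids, generation_config=generation_config)

# 3) The override mechanics in isolation: `update` consumes matching attributes and
#    returns the leftovers, which `generate` forwards as model kwargs.
config_copy = GenerationConfig(max_length=20)
unused = config_copy.update(max_length=64, attention_mask=None)
assert config_copy.max_length == 64 and "attention_mask" in unused
```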
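The `_get_logits_warper` and `_get_logits_processor` rewrites above only change where the thresholds are read from; the warper classes themselves are untouched. A standalone sketch of the same assembly logic, assuming the Flax warpers keep their `(input_ids, scores, cur_len)` call signature and remain importable from `transformers.generation.flax_logits_process`; the dummy logits are made up:

```python
import jax.numpy as jnp
from transformers import GenerationConfig
from transformers.generation.flax_logits_process import (
    FlaxLogitsProcessorList,
    FlaxTemperatureLogitsWarper,
    FlaxTopKLogitsWarper,
    FlaxTopPLogitsWarper,
)

generation_config = GenerationConfig(do_sample=True, temperature=0.7, top_k=5, top_p=0.95)

# Mirror of the conditions in `_get_logits_warper`: a warper is only appended when its
# parameter is set and would actually filter something.
warpers = FlaxLogitsProcessorList()
if generation_config.temperature is not None and generation_config.temperature != 1.0:
    warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
    warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
    warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))

# Apply the chained warpers to a dummy batch of next-token logits.
input_ids = jnp.zeros((1, 1), dtype="i4")
scores = jnp.arange(8, dtype=jnp.float32).reshape(1, 8)  # fake logits over an 8-token vocabulary
warped_scores = warpers(input_ids, scores, cur_len=1)
```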
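Finally, the `can_generate` check added to `modeling_flax_utils.py` leans on the fact that the repr of a bound method names the class that defined it. A toy, transformers-free illustration of that detection idea (all class names here are made up):

```python
class FakeGenerationMixin:
    # Plays the role of the `prepare_inputs_for_generation` stub added to the mixin above.
    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError("Define `prepare_inputs_for_generation` to use `generate`.")


class ModelWithoutGeneration(FakeGenerationMixin):
    pass  # inherits the stub unchanged


class ModelWithGeneration(FakeGenerationMixin):
    def prepare_inputs_for_generation(self, input_ids):
        return {"input_ids": input_ids}


def can_generate(model) -> bool:
    # If the bound method still comes from the mixin, its repr mentions the mixin's name.
    return "GenerationMixin" not in str(model.prepare_inputs_for_generation)


print(can_generate(ModelWithoutGeneration()))  # False -> `generate` would raise a helpful error
print(can_generate(ModelWithGeneration()))     # True
```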