diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py
index 5d936ce5b1dc..9af0797e8897 100644
--- a/src/transformers/generation/flax_utils.py
+++ b/src/transformers/generation/flax_utils.py
@@ -305,10 +305,10 @@ def generate(
         >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
         >>> input_context = "The dog"
         >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
+        >>> inputs = tokenizer(input_context, return_tensors="np")
         >>> # generate candidates using sampling
-        >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
-        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        >>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
+        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
         ```"""
         # Validate the `.generate()` call
         self._validate_model_class()
@@ -323,6 +323,17 @@ def generate(
             )
         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
 
+        if pad_token_id is None and eos_token_id is not None:
+            if model_kwargs.get("attention_mask") is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            pad_token_id = eos_token_id
+
         if decoder_start_token_id is None and self.config.is_encoder_decoder:
             raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
 
@@ -525,8 +536,8 @@ def _greedy_search(
 
         batch_size, cur_len = input_ids.shape
 
-        eos_token_id = jnp.array(eos_token_id)
-        pad_token_id = jnp.array(pad_token_id)
+        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
+        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
         cur_len = jnp.array(cur_len)
 
         # per batch-item holding current token in loop.
@@ -614,8 +625,8 @@ def _sample(
 
         batch_size, cur_len = input_ids.shape
 
-        eos_token_id = jnp.array(eos_token_id)
-        pad_token_id = jnp.array(pad_token_id)
+        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
+        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
         cur_len = jnp.array(cur_len)
 
         # per batch-item holding current token in loop.
@@ -748,8 +759,8 @@ def gather_fn(tensor):
 
         batch_size, num_beams, cur_len = input_ids.shape
 
-        eos_token_id = jnp.array(eos_token_id)
-        pad_token_id = jnp.array(pad_token_id)
+        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
+        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
         cur_len = jnp.array(cur_len)
 
         # per batch,beam-item holding current token in loop.
diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index 5d619d1d19c2..774389a956f0 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -702,11 +702,11 @@ def generate(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
                     "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                 )
-            logger.warning(
-                f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate"
-                " sequence"
-            )
-            generation_config.pad_token_id = generation_config.eos_token_id
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            generation_config.pad_token_id = eos_token_id
 
         use_xla = not tf.executing_eagerly()
         if use_xla and not self.supports_xla_generation:
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 7839f58a2016..ab86425d8a17 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -13,6 +13,7 @@ docs/source/en/model_doc/tapex.mdx
 docs/source/en/model_doc/donut.mdx
 docs/source/en/model_doc/encoder-decoder.mdx
 src/transformers/generation/configuration_utils.py
+src/transformers/generation/flax_utils.py
 src/transformers/generation/tf_utils.py
 src/transformers/generation/utils.py
 src/transformers/models/albert/configuration_albert.py
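The Flax and TF hunks above share the same behavior: when `pad_token_id` is unset but an EOS token exists, generation falls back to the (first) EOS token as the pad token and warns the user. Below is a minimal standalone sketch of that fallback, useful for seeing the intended behavior in isolation; the helper name `resolve_pad_token_id` and the `logging` setup are illustrative only and are not part of the transformers API.

```python
# Minimal sketch of the pad_token_id fallback introduced in the patch.
# `resolve_pad_token_id` is a hypothetical helper, not a transformers function.
import logging
from typing import List, Optional, Union

logger = logging.getLogger(__name__)


def resolve_pad_token_id(
    pad_token_id: Optional[int],
    eos_token_id: Optional[Union[int, List[int]]],
    attention_mask=None,
) -> Optional[int]:
    """If no pad token is set, reuse the (first) EOS token and warn, mirroring the patched logic."""
    if pad_token_id is None and eos_token_id is not None:
        if attention_mask is None:
            logger.warning(
                "The attention mask and the pad token id were not set. As a consequence, you may observe "
                "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
            )
        if isinstance(eos_token_id, list):
            eos_token_id = eos_token_id[0]
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
        pad_token_id = eos_token_id
    return pad_token_id


# e.g. a config with eos_token_id=[2, 50257] and no pad token falls back to 2,
# while an explicitly set pad token is left untouched.
assert resolve_pad_token_id(None, [2, 50257]) == 2
assert resolve_pad_token_id(0, [2, 50257]) == 0
```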