Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions src/transformers/generation/flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ def generate(
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
>>> input_context = "The dog"
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
>>> inputs = tokenizer(input_context, return_tensors="np")
>>> # generate candidates using sampling
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
>>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
```"""
# Validate the `.generate()` call
self._validate_model_class()
Expand All @@ -323,6 +323,17 @@ def generate(
)
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

if pad_token_id is None and eos_token_id is not None:
if model_kwargs.get("attention_mask") is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id

if decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")

Expand Down Expand Up @@ -525,8 +536,8 @@ def _greedy_search(

batch_size, cur_len = input_ids.shape

eos_token_id = jnp.array(eos_token_id)
pad_token_id = jnp.array(pad_token_id)
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)

# per batch-item holding current token in loop.
Expand Down Expand Up @@ -614,8 +625,8 @@ def _sample(

batch_size, cur_len = input_ids.shape

eos_token_id = jnp.array(eos_token_id)
pad_token_id = jnp.array(pad_token_id)
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)

# per batch-item holding current token in loop.
Expand Down Expand Up @@ -748,8 +759,8 @@ def gather_fn(tensor):

batch_size, num_beams, cur_len = input_ids.shape

eos_token_id = jnp.array(eos_token_id)
pad_token_id = jnp.array(pad_token_id)
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)

# per batch,beam-item holding current token in loop.
Expand Down
10 changes: 5 additions & 5 deletions src/transformers/generation/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,11 +702,11 @@ def generate(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
logger.warning(
f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate"
" sequence"
)
generation_config.pad_token_id = generation_config.eos_token_id
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
Comment on lines +705 to +709
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took the opportunity also to copy the logic to TF, so it can also handle eos_token_id as a list 👀


use_xla = not tf.executing_eagerly()
if use_xla and not self.supports_xla_generation:
Expand Down
1 change: 1 addition & 0 deletions utils/documentation_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ docs/source/en/model_doc/tapex.mdx
docs/source/en/model_doc/donut.mdx
docs/source/en/model_doc/encoder-decoder.mdx
src/transformers/generation/configuration_utils.py
src/transformers/generation/flax_utils.py
src/transformers/generation/tf_utils.py
src/transformers/generation/utils.py
src/transformers/models/albert/configuration_albert.py
Expand Down