12 changes: 11 additions & 1 deletion src/transformers/pipelines/text_generation.py
@@ -158,6 +158,7 @@ def _sanitize_parameters(
        max_length=None,
        continue_final_message=None,
        skip_special_tokens=None,
        tokenizer_encode_kwargs=None,
        **generate_kwargs,
    ):
        # preprocess kwargs
@@ -194,6 +195,10 @@ def _sanitize_parameters(

        if continue_final_message is not None:
            preprocess_params["continue_final_message"] = continue_final_message

        if tokenizer_encode_kwargs is not None:
            preprocess_params["tokenizer_encode_kwargs"] = tokenizer_encode_kwargs

        preprocess_params.update(generate_kwargs)

        # forward kwargs
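As a quick sanity check of this routing, a minimal sketch (assuming `pipe` is an existing text-generation pipeline instance and the usual `Pipeline` contract where `_sanitize_parameters` returns the preprocess, forward, and postprocess parameter dicts):

# Sketch: the new kwarg should be routed to the preprocess step, not captured by **generate_kwargs
preprocess_params, forward_params, postprocess_params = pipe._sanitize_parameters(
    tokenizer_encode_kwargs={"padding": True}
)
assert preprocess_params["tokenizer_encode_kwargs"] == {"padding": True}
assert "tokenizer_encode_kwargs" not in forward_params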
@@ -288,6 +293,9 @@ def __call__(self, text_inputs, **kwargs):
                - `None` : default strategy where nothing in particular happens
                - `"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
                  truncate a lot of the prompt and not suitable when generation exceeds the model capacity)
            tokenizer_encode_kwargs (`dict`, *optional*):
                Additional keyword arguments to pass along to the encoding step of the tokenizer. If the text input is
                a chat, they are passed to `apply_chat_template`. Otherwise, they are passed to `__call__`.
            generate_kwargs (`dict`, *optional*):
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./text_generation)).
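A usage sketch for the new argument (assumptions: the tiny test model below is the one used in the test added by this PR, and `enable_thinking` only has an effect if the chat template actually reads it; any kwarg accepted by `apply_chat_template` or the tokenizer's `__call__` is forwarded the same way):

from transformers import pipeline

generator = pipeline("text-generation", "hf-internal-testing/tiny-gpt2-with-chatml-template", max_new_tokens=5)

# Chat input: the extra kwargs go to apply_chat_template
chat = [{"role": "user", "content": "This is a test"}]
out = generator(chat, tokenizer_encode_kwargs={"enable_thinking": True})

# Plain string input: the extra kwargs go to the tokenizer's __call__
out = generator("This is a test", tokenizer_encode_kwargs={"add_special_tokens": False})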
@@ -333,16 +341,18 @@ def preprocess(
        padding=None,
        max_length=None,
        continue_final_message=None,
        tokenizer_encode_kwargs=None,
        **generate_kwargs,
    ):
        # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
        tokenizer_kwargs = {
            "add_special_tokens": add_special_tokens,
            "truncation": truncation,
            "padding": padding,
            "max_length": max_length,  # TODO: name clash -- this is broken, `max_length` is also a `generate` arg
            "max_length": max_length,  # NOTE: `max_length` is also a `generate` arg. Use `tokenizer_encode_kwargs` to avoid a name clash
        }
        tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None}
        tokenizer_kwargs.update(tokenizer_encode_kwargs or {})

        if isinstance(prompt_text, Chat):
            tokenizer_kwargs.pop("add_special_tokens", None)  # ignore add_special_tokens on chats
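Because the `update` runs after the named kwargs are collected, values in `tokenizer_encode_kwargs` take precedence, which also gives a way around the `max_length` name clash flagged in the comment above. A minimal sketch of that idea (assuming `generator` is a text-generation pipeline instance):

# Top-level `max_length` collides with the generate argument of the same name; to cap the
# encoding length instead, route it through tokenizer_encode_kwargs
out = generator(
    "This is a test",
    max_new_tokens=5,  # generation budget
    tokenizer_encode_kwargs={"truncation": True, "max_length": 32},  # tokenizer-side truncation
)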
18 changes: 18 additions & 0 deletions tests/pipelines/test_pipelines_text_generation.py
@@ -13,6 +13,7 @@
# limitations under the License.

import unittest
from unittest.mock import patch

from transformers import (
    MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -555,3 +556,20 @@ def test_pipeline_skip_special_tokens(self):
        # forcing special tokens to be included in the output
        output = generator(chat, max_new_tokens=1000, do_sample=False, skip_special_tokens=False)
        self.assertIn("<end_of_turn>", str(output[0]["generated_text"]))

    @require_torch
    def test_forward_tokenizer_kwargs(self):
        chat = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a test"},
        ]
        model = "hf-internal-testing/tiny-gpt2-with-chatml-template"
        text_generator = pipeline("text-generation", model, max_new_tokens=5)
        tokenizer = text_generator.tokenizer

        with patch.object(tokenizer, "apply_chat_template", wraps=tokenizer.apply_chat_template) as mock:
            text_generator(chat, tokenizer_encode_kwargs={"enable_thinking": True})
            self.assertGreater(mock.call_count, 0)
            kw_call_args = mock.call_args[1]
            self.assertIn("enable_thinking", kw_call_args)
            self.assertEqual(kw_call_args["enable_thinking"], True)
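A complementary check that is not part of this PR but could be added inside the same test (a sketch, assuming the pipeline invokes `self.model.generate` in its forward step): verify that the encode kwargs do not leak into generation.

        # Sketch: the encode kwargs are consumed during preprocessing and should never reach generate
        with patch.object(text_generator.model, "generate", wraps=text_generator.model.generate) as gen_mock:
            text_generator(chat, tokenizer_encode_kwargs={"enable_thinking": True})
            self.assertNotIn("tokenizer_encode_kwargs", gen_mock.call_args[1])
            self.assertNotIn("enable_thinking", gen_mock.call_args[1])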