Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,13 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False):
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
template for this argument to have any effect.
continue_final_message (bool, *optional*):
continue_final_message (bool or str, *optional*):
If this is set, the chat will be formatted so that the final
message in the chat is open-ended, without any EOS tokens. The model will continue this message
rather than starting a new one. This allows you to "prefill" part of
the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
the model's response for it. If a string is passed, it will be used as the key for the field to continue
(e.g. "reasoning_content"). Cannot be used at the same time as `add_generation_prompt`.

return_assistant_tokens_mask (`bool`, defaults to `False`):
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
the mask will contain 1. For user and system tokens, the mask will contain 0.
Expand All @@ -514,7 +516,7 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False):
tools: list[dict] | None = None
documents: list[dict[str, str]] | None = None
add_generation_prompt: bool | None = False
continue_final_message: bool | None = False
continue_final_message: bool | str | None = False
return_assistant_tokens_mask: bool | None = False
reasoning_effort: str | None = None

Expand Down Expand Up @@ -1681,7 +1683,7 @@ def apply_chat_template(
tools: list[dict] | None = None,
documents: list[dict[str, str]] | None = None,
add_generation_prompt: bool = False,
continue_final_message: bool = False,
continue_final_message: bool | str = False,
return_assistant_tokens_mask: bool = False,
tokenize: bool = False,
return_tensors: str | TensorType | None = None,
Expand Down Expand Up @@ -1905,7 +1907,7 @@ def apply_chat_template(
text=prompt,
images=batch_images if images_exist else None,
videos=batch_videos if videos_exist else None,
audio=batch_audios if batch_audios else None,
audio=batch_audios or None,
**processor_kwargs,
)

Expand All @@ -1932,7 +1934,7 @@ def apply_chat_template(
# Ensure end_pos is also within bounds
if end_pos > len(input_ids[i]):
end_pos = len(input_ids[i])
for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
for token_id in range(start_pos, end_pos or len(input_ids[i])):
current_mask[token_id] = 1
assistant_masks.append(current_mask)
out["assistant_masks"] = assistant_masks
Expand Down
7 changes: 4 additions & 3 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2994,7 +2994,7 @@ def apply_chat_template(
documents: list[dict[str, str]] | None = None,
chat_template: str | None = None,
add_generation_prompt: bool = False,
continue_final_message: bool = False,
continue_final_message: bool | str = False,
tokenize: bool = True,
padding: bool | str | PaddingStrategy = False,
truncation: bool = False,
Expand Down Expand Up @@ -3031,11 +3031,12 @@ def apply_chat_template(
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
template for this argument to have any effect.
continue_final_message (bool, *optional*):
continue_final_message (bool or str, *optional*):
If this is set, the chat will be formatted so that the final
message in the chat is open-ended, without any EOS tokens. The model will continue this message
rather than starting a new one. This allows you to "prefill" part of
the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
the model's response for it. If a string is passed, it will be used as the key for the field to continue
(e.g. "reasoning_content"). Cannot be used at the same time as `add_generation_prompt`.
tokenize (`bool`, defaults to `True`):
Whether to tokenize the output. If `False`, the output will be a string.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Expand Down
19 changes: 14 additions & 5 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def render_jinja_template(
documents: ChatType | None = None,
chat_template: str | None = None,
return_assistant_tokens_mask: bool = False,
continue_final_message: bool = False,
continue_final_message: bool | str = False,
add_generation_prompt: bool = False,
**kwargs,
) -> str:
Expand Down Expand Up @@ -543,9 +543,17 @@ def render_jinja_template(
chat = chat.messages
if continue_final_message:
chat = deepcopy(chat)
final_message = chat[-1].get("content")
if final_message is None:
raise ValueError("continue_final_message is set but the final message has no content to continue!")
continue_final_message = continue_final_message if isinstance(continue_final_message, str) else "content"

if (final_message := chat[-1].get(continue_final_message)) is None:
raise ValueError(
f'continue_final_message is set but the final message has no "{continue_final_message}" to continue!'
)
if continue_final_message not in chat_template:
raise ValueError(
f'continue_final_message is set to "{continue_final_message}" but this is not an accepted field in the chat_template'
)

elif isinstance(final_message, (list, tuple)):
for content_block in reversed(final_message):
if "text" in content_block:
Expand All @@ -558,7 +566,7 @@ def render_jinja_template(
"continue_final_message is set but we could not find any text to continue in the final message!"
)
else:
chat[-1]["content"] = chat[-1]["content"] + continue_final_message_tag
chat[-1][continue_final_message] = chat[-1][continue_final_message] + continue_final_message_tag
if return_assistant_tokens_mask:
rendered_chat, generation_indices = _render_with_assistant_indices(
compiled_template=compiled_template,
Expand Down Expand Up @@ -586,6 +594,7 @@ def render_jinja_template(
"applying the chat template! This can happen if the chat template deletes portions of "
"the final message. Please verify the chat template and final message in your chat to "
"ensure they are compatible."
f"Final message to continue: {final_message.strip()}\nRendered chat:\n{rendered_chat}"
)
tag_loc = rendered_chat.rindex(continue_final_message_tag.strip())
if rendered_chat[tag_loc : tag_loc + len(continue_final_message_tag)] == continue_final_message_tag:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,38 @@ def test_continue_final_message_with_decoy_earlier_message(self):
"<|im_start|>user\nhi 0<|im_end|>\n<|im_start|>assistant\nbye: 0<|im_end|>\n<|im_start|>user\nhi 1<|im_end|>\n<|im_start|>assistant\nbye:",
)

@require_jinja
def test_continue_final_message_string_and_reasoning(self):
dummy_template = """
{%- for message in messages %}
{{- "<|im_start|>" + message['role'] + "\n" }}
{%- if message['reasoning_content'] is defined %}
{{- "<think>\n" + message['reasoning_content'] + "\n</think>\n" }}
{%- endif %}
{{- message['content'] + "<|im_end|>" + "\n"}}
{%- endfor %}"""
dummy_conversation = [
{"role": "user", "content": "user message"},
{
"role": "assistant",
"reasoning_content": "assistant reasoning...",
"content": "assistant message", # not shown because the continue_final_message is set at "reasoning_content"
},
]
tokenizer = self.get_tokenizer()

# Test continue_final_message="reasoning_content"
prefill_output = tokenizer.apply_chat_template(
dummy_conversation,
chat_template=dummy_template,
tokenize=False,
continue_final_message="reasoning_content",
)
self.assertEqual(
prefill_output,
"<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\n<think>\nassistant reasoning...",
)

@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"
Expand Down
Loading