Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
requires_backends,
to_py_obj,
)
from .utils.chat_parsing_utils import recursive_parse
from .utils.chat_template_utils import render_jinja_template
from .utils.import_utils import PROTOBUF_IMPORT_ERROR

Expand Down Expand Up @@ -1521,6 +1522,7 @@ def apply_chat_template(
tools: Optional[list[Union[dict, Callable]]] = None,
documents: Optional[list[dict[str, str]]] = None,
chat_template: Optional[str] = None,
chat_schema: Optional[dict] = None,
add_generation_prompt: bool = False,
continue_final_message: bool = False,
tokenize: bool = True,
Expand Down Expand Up @@ -1556,6 +1558,9 @@ def apply_chat_template(
chat_template (`str`, *optional*):
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
argument, as the model's template will be used by default.
chat_schema (`dict`, *optional*):
A JSON schema dict with optional parsing fields, used to indicate the model's input spec and allow
parsing of rendered chats.
add_generation_prompt (bool, *optional*):
If this is set, a prompt with the token(s) that indicate
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
Expand Down Expand Up @@ -1618,6 +1623,8 @@ def apply_chat_template(
tokenizer_kwargs = {}

chat_template = self.get_chat_template(chat_template, tools)
if chat_schema is None and getattr(self, "chat_schema", None) is not None:
chat_schema = self.chat_schema

if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
Expand Down Expand Up @@ -1648,6 +1655,18 @@ def apply_chat_template(
**template_kwargs,
)

if return_assistant_tokens_mask and chat_schema:
# TODO This probably needs to be reverted - in cases with tool calls or thinking, we are unlikely to
# correctly capture the assistant boundaries with the regexes, without making them a lot more
# failure-prone. Generation tags in the template are more reliable and should be fully embraced.
generation_indices = [] # This takes priority over jinja generation parsing
for chat in rendered_chat:
parsed_chat = recursive_parse(chat, chat_schema, offset=0)
assistant_messages = [
message["content"] for message in parsed_chat.get("messages", []) if message["role"] == "assistant"
]
generation_indices.append([(message.start, message.end) for message in assistant_messages])

if not is_batched:
rendered_chat = rendered_chat[0]

Expand Down
227 changes: 227 additions & 0 deletions src/transformers/utils/chat_parsing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import json
import re
import jmespath # TODO Make this a proper optional dependency

# Next line only used because eval might grab them. Can be removed once we have something better than eval
from typing import Any, Dict, List, Optional, Tuple, Union # noqa: F401


def _parse_re_match(node_match, require_groups: list[str] | None = None):
if require_groups:
if not node_match.groupdict():
raise ValueError(f"Regex has no named groups, but require_groups was set to {require_groups}")
for group in require_groups:
if group not in node_match.groupdict():
raise ValueError(f"Regex missing required group {group}!\nGroups: {node_match.groupdict().keys()}\n")
# If the regex has named groups, return a dict of those groups
if node_match.groupdict():
return {key: val for key, val in node_match.groupdict().items() if val is not None}
# If the regex has unnamed groups, it MUST only have one, and we return that group
elif groups := list(node_match.groups()):
if len(groups) > 1:
raise ValueError(f"Regex has multiple unnamed groups!\nGroups: {groups}\n")
return groups[0]
# If no groups, use the whole match
else:
return node_match.group(0)


def recursive_parse(
node_content: str | list | dict,
node_schema: dict,
):
"""
This function takes content and a JSON schema node which includes
regex extractors, and recursively parses the content according to the schema. It uses recursion to handle
nested schemas.

Args:
node_content: The content corresponding to this node. Usually a string, but can be something else
if the parent node has multiple capture groups or named groups. In that case,
we generally pass the capture groups straight through to the children of this node
and don't do any parsing at this level.
node_schema: The schema node controlling the parsing.

Returns:
The parsed data structure for the current node.
"""

# If the schema has a const, we just return that value and do absolutely nothing else
if "const" in node_schema:
return node_schema["const"]

# If the node content is None, we return None. EZ.
if node_content is None:
return None

# If not, we have to do a little parsing. First, set some vars and do basic validation
node_type = node_schema["type"]
has_regex = "x-regex" in node_schema or "x-regex-iterator" in node_schema or "x-regex-to-dict" in node_schema
if has_regex and not isinstance(node_content, str):
raise TypeError(
"Schema node got a non-string input, but has a regex for parsing.\n"
f"Input: {node_content}\n"
f"Schema: {node_schema}"
)

node_regex = node_schema.get("x-regex")
node_regex_iterator = node_schema.get("x-regex-iterator")
node_regex_to_dict = node_schema.get("x-regex-to-dict")
if node_regex is not None:
node_match = re.search(node_regex, node_content, flags=re.DOTALL)
if not node_match:
return None # TODO Is this correct? Should I raise an error?
node_content = _parse_re_match(node_match)
if node_regex_iterator is not None:
if node_type != "array":
raise TypeError(f"Schema node with type {node_type} cannot use x-regex-iterator.\nSchema: {node_schema}")
# Note that this can be applied after a standard node-regex search
node_content = [
_parse_re_match(node_match)
for node_match in re.finditer(node_regex_iterator, node_content, flags=re.DOTALL)
]
if not node_content:
return None # TODO Is this correct? Should I raise an error?
if node_regex_to_dict is not None:
if node_type != "object":
raise TypeError(f"Schema node with type {node_type} cannot use x-regex-to-dict.\nSchema: {node_schema}")
# Note that this can be applied after a standard node-regex search
output_content = {}
for node_match in re.finditer(node_regex_to_dict, node_content, flags=re.DOTALL):
match_groups = _parse_re_match(node_match, require_groups=["key", "value"])
output_content[match_groups["key"]] = match_groups["value"]
node_content = output_content
if not node_content:
return None

# Next, if the node has a parser, apply it. We do this after regexes so that the regex can extract
# a substring to parse, if needed.
if "x-parser" in node_schema:
parser = node_schema["x-parser"]
if parser == "json":
if not isinstance(node_content, str):
raise TypeError(
f"Node has JSON parser but got non-string input: {node_content}\nSchema: {node_schema}"
)
parser_args = node_schema.get("x-parser-args", {})
try:
parsed_json = json.loads(node_content)
except json.JSONDecodeError as e:
raise ValueError(
f"Node has JSON parser but could not parse its contents as JSON: {node_content}\nError: {e}"
)
if "transform" in parser_args:
parsed_json = jmespath.search(parser_args["transform"], parsed_json)
node_content = parsed_json
else:
raise ValueError(f"Unknown parser {parser} for schema node: {node_schema}")

# If there's a mapping, apply it now
if "x-mapping" in node_schema:
if not isinstance(node_content, str):
raise TypeError(
f"Schema node with type {node_type} cannot use x-mapping on non-string content.\n"
f"Content: {node_content}\n"
f"Schema: {node_schema}"
)
mapping = node_schema["x-mapping"]
if node_content in mapping:
node_content = mapping[node_content]

if "x-mapping-regex" in node_schema:
if not isinstance(node_content, str):
raise TypeError(
f"Schema node with type {node_type} cannot use x-mapping-regex on non-string content.\n"
f"Content: {node_content}\n"
f"Schema: {node_schema}"
)
mapping_regex = node_schema["x-mapping-regex"]
for pattern, replacement in mapping_regex.items():
node_content = re.sub(pattern, replacement, node_content, flags=re.DOTALL)

# Finally, handle parsed content based on schema type and recurse if required
if node_type == "object":
parsed_schema = {}
if isinstance(node_content, str):
# This means we don't have a regex at this level, so all of our child nodes need to parse the whole
# string themselves to extract their value.
if "properties" not in node_schema:
raise ValueError(
f"Object node received string content but has no regex or parser to handle it.\n"
f"Content: {node_content}\n"
f"Schema: {node_schema}"
)
for key, child_node in node_schema["properties"].items():
child_node_content = recursive_parse(node_content, node_schema["properties"][key])
if child_node_content is not None:
parsed_schema[key] = child_node_content
return parsed_schema
elif isinstance(node_content, dict):
for key, child_node in node_schema.get("properties", {}).items():
# TODO Error if required keys are not present
if key in node_content:
parsed_schema[key] = recursive_parse(node_content[key], child_node)
elif "default" in child_node:
# TODO Do I want to allow defaults?
parsed_schema[key] = child_node["default"]
else:
pass # TODO Add an error for required keys not present
if "additionalProperties" in node_schema:
for key, value in node_content.items():
if key not in node_schema.get("properties", {}):
parsed_schema[key] = recursive_parse(value, node_schema["additionalProperties"])
return parsed_schema
else:
breakpoint()
raise TypeError(f"Expected a dict or str for schema node with type object, got {node_content}")
elif node_type == "array":
if not node_content:
return []
parsed_schema = []
if "items" in node_schema:
if not isinstance(node_content, list):
raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}")
for item in node_content:
parsed_schema.append(recursive_parse(item, node_schema["items"]))
return parsed_schema
elif "prefixItems" in node_schema:
if not isinstance(node_content, list):
if len(node_schema["prefixItems"]) == 1:
# If there's only one prefix item, this is a single item array, we can just wrap the string
node_content = [node_content]
else:
raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}")
if len(node_content) != len(node_schema["prefixItems"]):
raise ValueError(
f"Array node has {len(node_content)} items, but schema only has "
f"{len(node_schema['prefixItems'])} prefixItems defined.\n"
f"Content: {node_content}\n"
f"Schema: {node_schema}"
)
for item, item_schema in zip(node_content, node_schema["prefixItems"]):
parsed_schema.append(recursive_parse(item, item_schema))
return parsed_schema
else:
raise ValueError(f"Array node has no items or prefixItems schema defined.\nSchema: {node_schema}")
elif node_type in ("string", "integer", "number", "boolean"):
if not isinstance(node_content, str):
raise TypeError(f"Expected a string for schema node with type {node_type}, got {node_content}")
if node_type == "integer":
return int(node_content)
elif node_type == "number":
return float(node_content)
elif node_type == "boolean":
if node_content.lower() in ("true", "1"):
return True
elif node_content.lower() in ("false", "0"):
return False
else:
raise ValueError(f"Invalid boolean value: {node_content}")
else:
# String type
return node_content
elif node_type == "any":
return node_content
else:
# TODO Should we handle null types?
raise TypeError(f"Unsupported schema type {node_type} for node: {node_content}")
13 changes: 8 additions & 5 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,13 @@
# Splits the Args: block into individual arguments. An argument name may be followed by an optional
# parenthesised type before the colon, e.g. `x (int): description`; the type is matched but not captured.
args_split_re = re.compile(
    r"""
    (?:^|\n)                               # Start of the args block or a newline
    \s*(\w+)                               # Capture the argument name
    (?:\s*\([^)]*\))?                      # Optional (type) with optional surrounding spaces
    \s*:\s*                                # Colon (allowing spaces around it)
    (.*?)\s*                               # Capture the description (multi-line), trim trailing spaces
    (?=\n\s*\w+(?:\s*\([^)]*\))?\s*:|\Z)   # Next arg (optionally with type) or end
    """,
    re.DOTALL | re.VERBOSE,
)
Expand Down Expand Up @@ -101,7 +104,7 @@ def _get_json_schema_type(param_type: type) -> dict[str, str]:
return type_mapping.get(param_type, {"type": "object"})


def _parse_type_hint(hint: str) -> dict:
def _parse_type_hint(hint) -> dict:
origin = get_origin(hint)
args = get_args(hint)

Expand Down
Loading
Loading