diff --git a/pyproject.toml b/pyproject.toml index 9092b594..75176bc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ soundfile = ["soundfile>=0.12.1"] soxr = ["soxr>=0.5.0"] audio = ["mistral_common[soundfile]", "mistral_common[soxr]"] image = ["mistral_common[opencv]"] - +guidance = ["llguidance>=1.3.0,<1.4.0", "jinja2>=3.1.0"] hf-hub = ["huggingface-hub>=1.0"] server = [ "fastapi[standard]>=0.118.3", @@ -41,6 +41,9 @@ server = [ "click>=8.1.0", "uvloop>=0.22.1; python_version >= '3.14'" ] +all = [ + "mistral_common[opencv,sentencepiece,audio,image,guidance,hf-hub,server]" +] [project.scripts] mistral_common = "mistral_common.experimental.app.main:cli" @@ -90,7 +93,7 @@ warn_unused_ignores = true exclude = ["docs", "tools", "build"] [[tool.mypy.overrides]] -module = ["sentencepiece.*", "cv2", "cv2.*","soxr", "soundfile"] +module = ["sentencepiece.*", "cv2", "cv2.*","soxr", "soundfile", "llguidance", "llguidance.*", "jinja2", "jinja2.*"] ignore_missing_imports = true [tool.pytest.ini_options] diff --git a/src/mistral_common/guidance/__init__.py b/src/mistral_common/guidance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mistral_common/guidance/data/base_grammar.lark.jinja b/src/mistral_common/guidance/data/base_grammar.lark.jinja new file mode 100644 index 00000000..84fbd433 --- /dev/null +++ b/src/mistral_common/guidance/data/base_grammar.lark.jinja @@ -0,0 +1,25 @@ +{% if json_schema_str != None -%} +{% if json_only -%} +start: SAFE_WS? %json {{ json_schema_str }} +{% else -%} +start: body | SAFE_WS? %json {{ json_schema_str }} +{% endif -%} +{% else -%} +start: body +{% endif -%} + +{% if mode == "auto" -%} +body: content | (content? fcalls) +{% elif mode == "any" -%} +body: fcalls +{% elif mode == "required" -%} +body: content? 
fcalls +{% elif mode == "none" -%} +body: content +{% endif -%} + +fcalls: {{ fcall }} + +content: (/(.|\n)+/)+ + +SAFE_WS: /[ \t\r\n]{1,8}/ diff --git a/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja b/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja new file mode 100644 index 00000000..7c98243d --- /dev/null +++ b/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja @@ -0,0 +1,29 @@ +{% if json_schema_str != None -%} +{% if json_only -%} +start: SAFE_WS? %json {{ json_schema_str }} +{% else -%} +start: body | SAFE_WS? %json {{ json_schema_str }} +{% endif -%} +{% else -%} +start: body +{% endif -%} + +{% if mode == "auto" -%} +body: think (content | fcalls) +{% elif mode == "any" or mode == "required" -%} +body: think fcalls +{% elif mode == "none" -%} +body: think content +{% endif -%} + +fcalls: {{ fcall }} + +NO_THINK: /(.|\n)+/ & ~/(?s:.*)(<\/think>|)(?s:.*)/ +end_think[capture, lazy]: /(.|\n)*<\/think>/ + +text_first_optional: (NO_THINK)* +content: NO_THINK + +think: SAFE_WS? // text_first_optional end_think + +SAFE_WS: /[ \t\r\n]{1,8}/ diff --git a/src/mistral_common/guidance/data/think_grammar.lark.jinja b/src/mistral_common/guidance/data/think_grammar.lark.jinja new file mode 100644 index 00000000..845285c0 --- /dev/null +++ b/src/mistral_common/guidance/data/think_grammar.lark.jinja @@ -0,0 +1,31 @@ +{% if json_schema_str != None -%} +{% if json_only -%} +start: json_grammar +{% else -%} +start: body | json_grammar +{% endif -%} + +{% if think_with_json -%} +json_grammar: think? SAFE_WS? %json {{ json_schema_str }} +{% else -%} +json_grammar: SAFE_WS? %json {{ json_schema_str }} +{% endif -%} +{% else -%} +start: body +{% endif -%} + +{% if mode == "auto" -%} +body: think? (content | fcalls) +{% elif mode == "any" or mode == "required" -%} +body: think? fcalls +{% elif mode == "none" -%} +body: think? content +{% endif -%} + +fcalls: content? 
{{ fcall }} + +content: (/(.|\n)+/)+ + +think: {{ begin_think_token }} content {{ end_think_token }} + +SAFE_WS: /[ \t\r\n]{1,8}/ diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py new file mode 100644 index 00000000..c4f251aa --- /dev/null +++ b/src/mistral_common/guidance/grammar_factory.py @@ -0,0 +1,283 @@ +import json +from enum import Enum +from functools import lru_cache +from pathlib import Path +from typing import Any, Callable + +from mistral_common.guidance.tokenizer import from_mistral_tokenizer +from mistral_common.imports import ( + assert_jinja2_installed, + assert_llguidance_installed, + is_jinja2_installed, + is_llguidance_installed, +) +from mistral_common.protocol.instruct.tool_calls import NamedToolChoice, Tool, ToolChoice, ToolChoiceEnum +from mistral_common.tokens.tokenizers.base import SpecialTokens, TokenizerVersion +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.tekken import is_tekkenizer + +if is_llguidance_installed(): + import llguidance as llg + +if is_jinja2_installed(): + from jinja2 import Template + +JINJA_DIR = Path(__file__).parent / "data" + + +def _validate_mode_and_tools(mode: ToolChoice, tools: list[Tool] | None) -> None: + if isinstance(mode, NamedToolChoice) and all(mode.function.name != tool.function.name for tool in (tools or [])): + raise ValueError( + f"Tool choice requires the {mode.function.name} tool but no tools with this name has been passed." 
+ ) + elif mode in [ToolChoiceEnum.any, ToolChoiceEnum.required] and not tools: + raise ValueError(f"When {mode=} please ensure to pass tools, got {tools=}.") + + +@lru_cache() +def _cached_get_jinja_template(tokenizer_version: TokenizerVersion, reasoning: bool) -> str: + if not reasoning: + jinja_key = _GrammarVariant.base + elif tokenizer_version < TokenizerVersion.v13: + jinja_key = _GrammarVariant.plain_think + else: + jinja_key = _GrammarVariant.think + + return JINJA_PATHS[jinja_key].read_text(encoding="utf-8") + + +@lru_cache() +def _cached_get_lark_from_jinja( + template: str, + mode: str, + fcall: str, + json_schema_str: str | None, + parallel_tool_calls: bool, + json_only: bool, + think_with_json: bool, + begin_think_token: str | None, + end_think_token: str | None, +) -> str: + jinja_template = Template(template) + lark_grammar = jinja_template.render( + mode=mode, + fcall=fcall, + json_schema_str=json_schema_str, + parallel_tool_calls=parallel_tool_calls, + json_only=json_only, + think_with_json=think_with_json, + begin_think_token=begin_think_token, + end_think_token=end_think_token, + ) + return lark_grammar + + +class _GrammarVariant(str, Enum): + base = "base" + plain_think = "plain_think" + think = "think" + + +JINJA_PATHS = { + _GrammarVariant.base: JINJA_DIR / "base_grammar.lark.jinja", + _GrammarVariant.plain_think: JINJA_DIR / "plain_text_think_grammar.lark.jinja", + _GrammarVariant.think: JINJA_DIR / "think_grammar.lark.jinja", +} + + +_TOOL_CALL_GRAMMAR = "{tool_calls_token} SAFE_WS? {tool_name} {args_token} SAFE_WS? %json {args_json} SAFE_WS?" 
+ + +def _get_tool_args_json(tool: Tool) -> dict[str, Any]: + r"""Returns the JSON schema for a tool's arguments.""" + args = tool.function.parameters if tool.function.strict else {"type": "object"} + return args or {"type": "object", "properties": {}, "additionalProperties": False} + + +def _convert_tool_calls( + tools: list[Tool] | None, + mode: ToolChoice, + parallel_tool_calls: bool, + get_special_token_id: Callable[[str], str], +) -> str: + r"""Converts tool calls to a lark grammar string. + + Args: + tools: The list of tools available. + mode: The tool choice mode. + parallel_tool_calls: Whether parallel tool calls are allowed. + get_special_token_id: Callable that maps a special token name to its lark grammar syntax. + + Returns: + The lark grammar string for tool calls. + """ + if mode == ToolChoiceEnum.none: + return "" + + tool_calls_token = get_special_token_id(SpecialTokens.tool_calls.value) + args_token = get_special_token_id(SpecialTokens.args.value) + + any_strict_true = any(tool.function.strict for tool in tools) if tools else False + + if not tools or not any_strict_true: + tool_name = f'"{mode.function.name}"' if isinstance(mode, NamedToolChoice) else "/.+/" + tool_entries = [(tool_name, '{"type": "object"}')] + else: + filtered_tools = ( + [next(tool for tool in tools if tool.function.name == mode.function.name)] + if isinstance(mode, NamedToolChoice) + else tools + ) + tool_entries = [ + (f'"{tool.function.name}"', json.dumps(_get_tool_args_json(tool), ensure_ascii=False)) + for tool in filtered_tools + ] + + grammar_parts = [ + _TOOL_CALL_GRAMMAR.format( + tool_calls_token=tool_calls_token, + args_token=args_token, + tool_name=name, + args_json=args_json, + ) + for name, args_json in tool_entries + ] + + grammar_tool_call = ( + " | ".join(f"({part})" for part in grammar_parts) if len(grammar_parts) > 1 else grammar_parts[0] + ) + + return f"({grammar_tool_call})+" if parallel_tool_calls else grammar_tool_call + + +class GrammarFactory: + 
r"""Generates grammars for a given tokenizer.""" + + @staticmethod + def is_supported(tokenizer: MistralTokenizer) -> bool: + r"""Checks whether the given tokenizer is supported by guidance. + + Guidance requires a Tekken tokenizer with version >= v11. + + Args: + tokenizer: The Mistral tokenizer to check. + + Returns: + Whether the tokenizer is supported. + """ + inner = tokenizer.instruct_tokenizer.tokenizer + return is_tekkenizer(inner) and not inner.version < TokenizerVersion.v11 + + def __init__(self, tokenizer: MistralTokenizer) -> None: + r"""Initialize the grammar factory. + + Args: + tokenizer: The Mistral tokenizer to generate grammars for. + + Raises: + ValueError: If the tokenizer is not supported (see + [`is_supported`][mistral_common.guidance.grammar_factory.GrammarFactory.is_supported]). + """ + assert_llguidance_installed() + assert_jinja2_installed() + self._tokenizer = tokenizer.instruct_tokenizer.tokenizer + if not self.is_supported(tokenizer): + raise ValueError( + f"Guidance requires a Tekken tokenizer with version >= v11, " + f"got {type(self._tokenizer).__name__} {self._tokenizer.version.value}" + ) + self._llg_tokenizer = from_mistral_tokenizer(tokenizer) + self._special_token_map = self._build_special_token_map() + + def _build_special_token_map(self) -> dict[str, str]: + r"""Build a mapping from special token names to their grammar syntax.""" + return {self._tokenizer.id_to_piece(i): f"<[{i}]>" for i in range(self._tokenizer.num_special_tokens)} + + def _special_token_lark(self, token_name: str) -> str: + r"""Convert special token name to lark grammar syntax.""" + assert token_name in self._special_token_map, f"Unknown special token: {token_name}" + return self._special_token_map[token_name] + + def _get_optional_special_token_lark(self, token_name: str) -> str | None: + r"""Returns lark grammar syntax for a special token, or `None` if absent.""" + return self._special_token_map.get(token_name) + + @property + def llg_tokenizer(self) -> 
"llg.LLTokenizer": + return self._llg_tokenizer + + def select_jinja_template(self, reasoning: bool) -> str: + r"""Selects and returns the appropriate jinja template content based on tokenizer version and reasoning mode. + + Args: + reasoning: Whether reasoning/thinking mode is enabled. + + Returns: + The jinja template content as a string. + """ + return _cached_get_jinja_template(tokenizer_version=self._tokenizer.version, reasoning=reasoning) + + def get_lark_from_jinja( + self, + template: str, + mode: ToolChoice, + tools: list[Tool] | None, + json_schema: dict[str, Any] | None, + parallel_tool_calls: bool, + json_only: bool = False, + ) -> str: + r"""Renders a lark grammar from a jinja template. + + Args: + template: Jinja template to render as a string. + mode: The function calling mode (auto, any, none). + tools: The list of tools available. + json_schema: JSON schema to additionally allow, unioned with the grammar. + parallel_tool_calls: Whether parallel tool calls are allowed. + json_only: If True, generates only JSON schema grammar without text/tool call alternatives. + + Returns: + The rendered lark grammar string. + """ + # Verifies that the NamedToolChoice has a valid tool and "any", "required" has tools. + _validate_mode_and_tools(mode=mode, tools=tools) + + fcall = _convert_tool_calls(tools, mode, parallel_tool_calls, self._special_token_lark) + json_schema_str = json.dumps(json_schema, ensure_ascii=False) if json_schema else None + # NamedToolChoice forces a specific tool, which maps to "required" grammar. 
+ template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else ToolChoiceEnum(mode) + think_with_json = self._tokenizer.version.supports_model_settings + + begin_think_token = self._get_optional_special_token_lark(SpecialTokens.begin_think.value) + end_think_token = self._get_optional_special_token_lark(SpecialTokens.end_think.value) + + return _cached_get_lark_from_jinja( + template=template, + mode=template_mode.value, + fcall=fcall, + json_schema_str=json_schema_str, + parallel_tool_calls=parallel_tool_calls, + json_only=json_only, + think_with_json=think_with_json, + begin_think_token=begin_think_token, + end_think_token=end_think_token, + ) + + def get_lark_for_json_schema(self, template: str, json_schema: dict[str, Any]) -> str: + r"""Returns a lark grammar that only accepts JSON objects matching the given schema. + + Args: + template: Jinja template to render as a string. + json_schema: The JSON schema to validate against. + + Returns: + The rendered lark grammar string that only matches the given JSON schema. 
+ """ + return self.get_lark_from_jinja( + template=template, + mode=ToolChoiceEnum.none, + tools=None, + json_schema=json_schema, + parallel_tool_calls=True, + json_only=True, + ) diff --git a/src/mistral_common/guidance/tokenizer.py b/src/mistral_common/guidance/tokenizer.py new file mode 100644 index 00000000..005e2ba8 --- /dev/null +++ b/src/mistral_common/guidance/tokenizer.py @@ -0,0 +1,109 @@ +import re + +from mistral_common.imports import assert_llguidance_installed, is_llguidance_installed +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenizer +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.tekken import is_tekkenizer + +if is_llguidance_installed(): + import llguidance as llg + + +class MistralLLGTokenizer: + r"""Wraps a Tekken tokenizer for use with llguidance.""" + + @property + def bos_token_id(self) -> int: + r"""The beginning of string token id.""" + return self._tokenizer.bos_id + + @property + def eos_token_id(self) -> int: + r"""The end of string token id.""" + return self._tokenizer.eos_id + + @property + def tokens(self) -> list[bytes]: + r"""The list of token byte representations.""" + return self._tokens + + @property + def special_token_ids(self) -> list[int]: + r"""The list of special token ids.""" + return self._special_token_ids + + def __init__(self, tokenizer: Tokenizer) -> None: + r"""Initialize the wrapper. + + Args: + tokenizer: The Tekken tokenizer to wrap for llguidance compatibility. + + Raises: + TypeError: If the tokenizer is not a Tekkenizer. + ValueError: If a special token has an invalid format. 
+ """ + assert_llguidance_installed() + + if not is_tekkenizer(tokenizer): + raise TypeError(f"Guidance only supports Tekken tokenizers, got {type(tokenizer)}") + + self._tokenizer = tokenizer + self._tokens: list[bytes] = [] + self._special_token_ids: list[int] = [] + + seen_special_tokens: set[str] = set() + for i in range(self._tokenizer.n_words): + # Convert square brackets to angle brackets for special tokens, + # since llg only recognizes the latter. + if i < self._tokenizer.num_special_tokens: + token_rep = self._tokenizer.id_to_piece(i) + if match := re.fullmatch(r"\[(.*)\]", token_rep): + token_rep_llg = f"<{match.group(1)}>" + else: + token_rep_llg = token_rep + + if not re.fullmatch(r"<.*>", token_rep_llg): + raise ValueError(f"Invalid special token: {token_rep_llg} ({token_rep})") + if token_rep_llg in seen_special_tokens: + raise ValueError(f"Duplicate special token: {token_rep_llg} (already seen: {seen_special_tokens})") + seen_special_tokens.add(token_rep_llg) + self._special_token_ids.append(i) + self._tokens.append(token_rep_llg.encode("utf-8")) + else: + token_bytes = self._tokenizer.id_to_byte_piece(i, SpecialTokenPolicy.RAISE) + self._tokens.append(token_bytes) + + if len(self._special_token_ids) != self._tokenizer.num_special_tokens: + raise ValueError( + f"Expected {self._tokenizer.num_special_tokens} special tokens, but found " + f"{len(self._special_token_ids)}" + ) + + def __call__(self, s: str) -> list[int]: + r"""Tokenizes a string into token ids. + + Args: + s: The string to tokenize. + + Returns: + The list of token ids. + """ + return self._tokenizer.encode(s, bos=False, eos=False) + + +def from_mistral_tokenizer(tokenizer: MistralTokenizer) -> "llg.LLTokenizer": + r"""Creates an llguidance tokenizer from a Mistral tokenizer. + + Args: + tokenizer: The Mistral tokenizer to convert. Must wrap a Tekkenizer. + + Returns: + The llguidance tokenizer. + + Raises: + TypeError: If the underlying tokenizer is not a Tekkenizer. 
+ """ + assert_llguidance_installed() + inner_tokenizer = tokenizer.instruct_tokenizer.tokenizer + tokenizer_data = MistralLLGTokenizer(inner_tokenizer) + return llg.LLTokenizer(llg.TokenizerWrapper(tokenizer_data)) diff --git a/src/mistral_common/imports.py b/src/mistral_common/imports.py index 98047e36..820ef541 100644 --- a/src/mistral_common/imports.py +++ b/src/mistral_common/imports.py @@ -24,6 +24,16 @@ def is_hf_hub_installed() -> bool: return is_package_installed("huggingface_hub") +@lru_cache() +def is_jinja2_installed() -> bool: + return is_package_installed("jinja2") + + +@lru_cache() +def is_llguidance_installed() -> bool: + return is_package_installed("llguidance") + + @lru_cache() def is_opencv_installed() -> bool: try: @@ -59,21 +69,36 @@ def is_soxr_installed() -> bool: return is_package_installed("soxr") +@lru_cache() def assert_hf_hub_installed() -> None: assert_package_installed("huggingface_hub", _get_dependency_error_message("huggingface_hub", "hf-hub")) +@lru_cache() +def assert_jinja2_installed() -> None: + assert_package_installed("jinja2", _get_dependency_error_message("jinja2", "guidance")) + + +@lru_cache() +def assert_llguidance_installed() -> None: + assert_package_installed("llguidance", _get_dependency_error_message("llguidance", "guidance")) + + +@lru_cache() def assert_opencv_installed() -> None: assert_package_installed("cv2", _get_dependency_error_message("opencv", "opencv")) +@lru_cache() def assert_sentencepiece_installed() -> None: assert_package_installed("sentencepiece", _get_dependency_error_message("sentencepiece", "sentencepiece")) +@lru_cache() def assert_soundfile_installed() -> None: assert_package_installed("soundfile", _get_dependency_error_message("soundfile", "soundfile")) +@lru_cache() def assert_soxr_installed() -> None: assert_package_installed("soxr", _get_dependency_error_message("soxr", "soxr")) diff --git a/src/mistral_common/tokens/tokenizers/sentencepiece.py 
b/src/mistral_common/tokens/tokenizers/sentencepiece.py index ada1d0c3..a7325af7 100644 --- a/src/mistral_common/tokens/tokenizers/sentencepiece.py +++ b/src/mistral_common/tokens/tokenizers/sentencepiece.py @@ -3,6 +3,7 @@ import warnings from functools import cached_property from pathlib import Path +from typing import TypeGuard import numpy as np @@ -259,3 +260,8 @@ def pad_id(self) -> int: def unk_id(self) -> int: r"""The unknown token id.""" return self._model.unk_id() # type: ignore + + +def is_sentencepiece_tokenizer(tokenizer: Tokenizer) -> TypeGuard[SentencePieceTokenizer]: + r"""Returns whether the tokenizer is a SentencePieceTokenizer.""" + return isinstance(tokenizer, SentencePieceTokenizer) diff --git a/src/mistral_common/tokens/tokenizers/tekken.py b/src/mistral_common/tokens/tokenizers/tekken.py index 27b52cf3..85f27926 100644 --- a/src/mistral_common/tokens/tokenizers/tekken.py +++ b/src/mistral_common/tokens/tokenizers/tekken.py @@ -5,7 +5,7 @@ from functools import cached_property from itertools import groupby from pathlib import Path -from typing import TypedDict +from typing import TypedDict, TypeGuard import numpy as np import tiktoken @@ -545,3 +545,8 @@ def _reload_mergeable_ranks( assert set(ranks.values()) == set(range(len(ranks))) return ranks + + +def is_tekkenizer(tokenizer: Tokenizer) -> TypeGuard[Tekkenizer]: + r"""Returns whether the tokenizer is a Tekkenizer.""" + return isinstance(tokenizer, Tekkenizer) diff --git a/tests/data/emoji.lark b/tests/data/emoji.lark new file mode 100644 index 00000000..6222e8f0 --- /dev/null +++ b/tests/data/emoji.lark @@ -0,0 +1,2 @@ +start: emoji+ +emoji: /(\p{Emoji_Presentation}|\p{Emoji}\uFE0F)/ diff --git a/tests/guidance/__init__.py b/tests/guidance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py new file mode 100644 index 00000000..70d9c856 --- /dev/null +++ b/tests/guidance/test_guidance.py @@ -0,0 +1,1685 @@ 
+from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal + +import llguidance as llg +import pytest +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from mistral_common.guidance.grammar_factory import GrammarFactory, _convert_tool_calls +from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk +from mistral_common.protocol.instruct.messages import AssistantMessage +from mistral_common.protocol.instruct.normalize import get_normalizer +from mistral_common.protocol.instruct.request import ReasoningEffort +from mistral_common.protocol.instruct.tool_calls import ( + Function, + FunctionCall, + FunctionName, + NamedToolChoice, + Tool, + ToolCall, + ToolChoice, + ToolChoiceEnum, + ToolTypes, +) +from mistral_common.protocol.instruct.validator import ValidationMode, get_validator +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenizer, TokenizerVersion +from mistral_common.tokens.tokenizers.instruct import ( + InstructTokenizerBase, + InstructTokenizerV11, + InstructTokenizerV13, + InstructTokenizerV15, +) +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.model_settings_builder import EnumBuilder, ModelSettingsBuilder +from mistral_common.tokens.tokenizers.tekken import Tekkenizer, is_tekkenizer +from tests.test_tekken import get_special_tokens, quick_vocab + +EMOJI_LARK_PATH = Path(__file__).parent.parent / "data" / "emoji.lark" + +Mode = Literal["auto", "any", "none", "required"] +_AUTO_ANY_REQUIRED: tuple[Mode, Mode, Mode] = ("auto", "any", "required") + + +_NUM_SPECIAL_TOKENS = 100 +_EXTRA_TOKENS = [ + b"de", + b"he", + b"llo", + "πŸ˜ƒ".encode("utf-8"), + "πŸ˜‚".encode("utf-8"), + "😊".encode("utf-8"), + "😍".encode("utf-8"), + "😘".encode("utf-8"), + "πŸ˜—".encode("utf-8"), + "πŸ˜™".encode("utf-8"), + "😚".encode("utf-8"), + "πŸ˜‹".encode("utf-8"), + "πŸ˜›".encode("utf-8"), + "😜".encode("utf-8"), + 
"😝".encode("utf-8"), + "πŸ€‘".encode("utf-8"), + "πŸ€—".encode("utf-8"), + "πŸ€”".encode("utf-8"), + "🀐".encode("utf-8"), + "😐".encode("utf-8"), + "πŸ˜‘".encode("utf-8"), + "😢".encode("utf-8"), + "😬".encode("utf-8"), + "こ".encode("utf-8"), + "γ‚“".encode("utf-8"), + "に".encode("utf-8"), + "け".encode("utf-8"), + "は".encode("utf-8"), + "Ω…Ψ±Ψ­Ψ¨Ψ§".encode("utf-8"), + "Ψ¨ΩƒΩ…".encode("utf-8"), + "في".encode("utf-8"), + "ΨΉΨ§Ω„Ω…".encode("utf-8"), + "Ψ§Ω„Ψ°ΩƒΨ§Ψ‘".encode("utf-8"), + "Ψ§Ω„Ψ§Ψ΅Ψ·Ω†Ψ§ΨΉΩŠ".encode("utf-8"), +] + + +def _build_tekken_mistral_tokenizer( + version: TokenizerVersion, + add_think: bool = False, + model_settings_builder: ModelSettingsBuilder | None = None, +) -> MistralTokenizer: + r"""Builds a MistralTokenizer wrapping a programmatic Tekkenizer.""" + special_tokens = get_special_tokens(version, add_think=add_think) + vocab = quick_vocab(_EXTRA_TOKENS) + + tekkenizer = Tekkenizer( + vocab, + special_tokens=special_tokens, + pattern=r"(?s:.+)", + vocab_size=len(vocab) + _NUM_SPECIAL_TOKENS, + num_special_tokens=_NUM_SPECIAL_TOKENS, + version=version, + model_settings_builder=model_settings_builder, + ) + + match version: + case TokenizerVersion.v11: + instruct_tokenizer = InstructTokenizerV11(tekkenizer) + case TokenizerVersion.v13: + instruct_tokenizer = InstructTokenizerV13(tekkenizer) + case TokenizerVersion.v15: + instruct_tokenizer = InstructTokenizerV15(tekkenizer) + case _: + raise ValueError(f"Unsupported version for programmatic Tekken build: {version}") + + normalizer = get_normalizer(version, tekkenizer.model_settings_builder) + validator = get_validator(version, mode=ValidationMode.test) + return MistralTokenizer(instruct_tokenizer, validator=validator, request_normalizer=normalizer) + + +@pytest.fixture(scope="module") +def v11_tekken() -> MistralTokenizer: + return _build_tekken_mistral_tokenizer(TokenizerVersion.v11) + + +@pytest.fixture(scope="module") +def v13_tekken() -> MistralTokenizer: + return 
_build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True) + + +_V15_MODEL_SETTINGS_BUILDER = ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=list(ReasoningEffort), + accepts_none=True, + default=None, + ), +) + + +@pytest.fixture(scope="module") +def v15_tekken() -> MistralTokenizer: + return _build_tekken_mistral_tokenizer( + TokenizerVersion.v15, add_think=True, model_settings_builder=_V15_MODEL_SETTINGS_BUILDER + ) + + +_PAYMENT_PARAMS: dict[str, Any] = { + "type": "object", + "additionalProperties": False, + "properties": {"transaction_id": {"type": "string", "description": "The transaction id."}}, + "required": ["transaction_id"], +} + + +class ToolProvider: + @staticmethod + def retrieve_payment_status(strict: bool) -> Tool: + return Tool( + function=Function( + name="retrieve_payment_status", + description="Get payment status of a transaction", + strict=strict, + parameters=_PAYMENT_PARAMS, + ) + ) + + @staticmethod + def retrieve_payment_date(strict: bool) -> Tool: + return Tool( + function=Function( + name="retrieve_payment_date", + description="Get payment date of a transaction", + strict=strict, + parameters=_PAYMENT_PARAMS, + ) + ) + + +class SchemaProvider: + @staticmethod + def basic_person() -> dict[str, Any]: + class Person(BaseModel): + model_config = ConfigDict(extra="forbid") + name: str + age: int + + return Person.model_json_schema() + + @staticmethod + def basic_dict_of_list() -> dict[str, Any]: + class DoMerge(BaseModel): + model_config = ConfigDict(extra="forbid") + new_clusters: dict[str, list[str]] = Field(default_factory=dict) + + return DoMerge.model_json_schema() + + +class TestCase(BaseModel): + __test__ = False + model_config = ConfigDict(arbitrary_types_allowed=True) + tokenizer: Tokenizer + mode: Literal["auto", "any", "none", "required"] | NamedToolChoice + tokens: list[int] + should_fail_on: int | None + case_name: str + reasoning: bool = False + parallel_tool_calls: bool = True + tools: 
list[Tool] | None = None + json_schema: dict[str, Any] | None = None + # When set, uses this raw lark grammar instead of GrammarFactory.get_lark + raw_lark: str | None = None + + @model_validator(mode="after") + def validate_should_fail_on(self) -> TestCase: + if self.should_fail_on is not None and self.should_fail_on < 0: + self.should_fail_on += len(self.tokens) + if self.should_fail_on < 0: + raise ValueError( + f"should_fail_on={self.should_fail_on + len(self.tokens)} " + f"is out of bounds for tokens of length {len(self.tokens)}" + ) + return self + + @property + def name(self) -> str: + return f"tekken_{self.tokenizer.version.value}_{self.mode}_{self.case_name}" + + +def _encode_content( + instruct_tokenizer: InstructTokenizerBase, + content: str | list[Any], +) -> list[int]: + tokenizer = instruct_tokenizer.tokenizer + + if isinstance(content, str): + return instruct_tokenizer.encode_assistant_message( + AssistantMessage(content=content), is_before_last_user_message=False, continue_message=False + ) + + tool_calls = [x for x in content if isinstance(x, ToolCall)] + content_chunks = [x for x in content if not isinstance(x, ToolCall)] + + tokens: list[int] = [] + if content_chunks: + tokens = instruct_tokenizer.encode_assistant_message( + AssistantMessage(content=content_chunks), + is_before_last_user_message=False, + continue_message=False, + ) + # The instruct tokenizer appends EOS after content, but when tool calls follow, + # the EOS should come after the last tool call, not after the content. Strip it + # here so the tool call tokens are appended directly after content tokens. 
+ while tokens and tokens[-1] == tokenizer.eos_id: + tokens.pop() + + for tc in tool_calls: + tokens += [ + tokenizer.get_special_token("[TOOL_CALLS]"), + *tokenizer.encode(tc.function.name, bos=False, eos=False), + tokenizer.get_special_token("[ARGS]"), + *tokenizer.encode(tc.function.arguments, bos=False, eos=False), + ] + tokens.append(tokenizer.eos_id) + return tokens + + +def _find_first_rejection( + factory: GrammarFactory, + tokens: list[int], + mode: ToolChoice, + tools: list[Tool] | None, +) -> int: + r"""Finds the index of the first token rejected by the grammar. + + Args: + factory: The grammar factory. + tokens: The token sequence to test. + mode: The tool choice mode (literal or NamedToolChoice). + tools: The tools to pass to grammar generation. + + Returns: + The index of the first rejected token. + + Raises: + ValueError: If all tokens are accepted. + """ + template = factory.select_jinja_template(reasoning=False) + grammar = factory.get_lark_from_jinja( + template=template, mode=mode, tools=tools, json_schema=None, parallel_tool_calls=True + ) + matcher = llg.LLMatcher(factory.llg_tokenizer, grammar) + for i, token in enumerate(tokens): + if not matcher.consume_token(token): + return i + raise ValueError("All tokens were accepted β€” expected a rejection") + + +def _token_debug_repr(tokenizer: Tekkenizer, token_id: int) -> str: + return repr(tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)) + + +def _generate_general_cases(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + assert isinstance(tokenizer, Tekkenizer) + + cases: list[TestCase] = [] + items = { + "newline": "\n", + "blank": "_", + "text": "Hello!", + "text_with_newlines": "Hello!\n\nHow are you?\nI'm fine, thanks!", + "emojis": 
"πŸ˜ƒπŸ˜‚πŸ˜ŠπŸ˜πŸ˜˜πŸ˜—πŸ˜™πŸ˜šπŸ˜‹πŸ˜›πŸ˜œπŸ˜πŸ€‘πŸ€—πŸ€”πŸ€πŸ˜πŸ˜‘πŸ˜ΆπŸ˜¬", + "japanese": "こんにけは", + "arabic": "Ω…Ψ±Ψ­Ψ¨Ψ§ Ψ¨ΩƒΩ… في ΨΉΨ§Ω„Ω… Ψ§Ω„Ψ°ΩƒΨ§Ψ‘ Ψ§Ω„Ψ§Ψ΅Ψ·Ω†Ψ§ΨΉΩŠ", + } + for case_name, content in items.items(): + tokens = _encode_content(instruct_tokenizer, content) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + case_name=case_name, + should_fail_on=None, + mode="auto", + reasoning=False, + ) + ) + if tokenizer.version < TokenizerVersion.v13: + # Count how many leading whitespace-only tokens the SAFE_WS? rule will consume + # before the grammar rejects (expecting ). + content_tokens = tokenizer.encode(content, bos=False, eos=False) + ws_prefix_len = 0 + for t in content_tokens: + piece = tokenizer.id_to_byte_piece(t, SpecialTokenPolicy.IGNORE) + if piece.strip(b" \t\r\n") == b"": + ws_prefix_len += 1 + else: + break + reasoning_fail = ws_prefix_len + else: + reasoning_fail = None + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + case_name=f"{case_name}_reasoning", + should_fail_on=reasoning_fail, + mode="auto", + reasoning=True, + ) + ) + return cases + + +def _count_prefix_tokens(tokenizer: Tokenizer, full_text: str, prefix: str) -> int: + r"""Counts the number of tokens that cover the prefix bytes in the full-text tokenization. + + BPE tokenization is context-dependent, so encoding a prefix in isolation may produce + different tokens than encoding the full string. This helper encodes the full string + and counts how many tokens are needed to cover the byte-length of the prefix. + + Args: + tokenizer: The tokenizer to use. + full_text: The complete text to tokenize. + prefix: The prefix whose byte-length determines the token count. + + Returns: + The number of tokens covering the prefix bytes. 
+ """ + assert is_tekkenizer(tokenizer) + prefix_byte_len = len(prefix.encode("utf-8")) + tokens = tokenizer.encode(full_text, bos=False, eos=False) + byte_count = 0 + for i, t in enumerate(tokens): + byte_count += len(tokenizer.id_to_byte_piece(t, SpecialTokenPolicy.IGNORE)) + if byte_count >= prefix_byte_len: + return i + 1 + return len(tokens) + + +def _generate_emoji_cases(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + emoji_lark = EMOJI_LARK_PATH.read_text(encoding="utf-8") + cases: list[TestCase] = [] + items: dict[str, tuple[str, int | None]] = { + "emojis_valid_a": ("πŸ˜ƒπŸ˜‚πŸ˜ŠπŸ˜πŸ˜˜πŸ˜—πŸ˜™πŸ˜šπŸ˜‹πŸ˜›πŸ˜œπŸ˜πŸ€‘πŸ€—πŸ€”πŸ€πŸ˜πŸ˜‘πŸ˜ΆπŸ˜¬", None), + "emojis_valid_b": ("πŸ˜ƒπŸ˜ƒπŸ˜ƒ", None), + "emojis_invalid_text": ("πŸ˜ƒsmile", _count_prefix_tokens(tokenizer, "πŸ˜ƒsmile", "πŸ˜ƒ")), + "emojis_invalid_space": ("πŸ˜ƒ ", _count_prefix_tokens(tokenizer, "πŸ˜ƒ ", "πŸ˜ƒ")), + } + for case_name, (text, should_fail_on) in items.items(): + tokens = tokenizer.encode(text, bos=False, eos=False) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + case_name=case_name, + should_fail_on=should_fail_on, + mode="auto", + raw_lark=emoji_lark, + ) + ) + return cases + + +def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + items: list[tuple[str, list[ToolCall], dict[Mode, int | None]]] = [ + ( + "single_fcall", + [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}'))], + {"auto": None, "none": 0}, + ), + ( + "multi_fcall", + [ + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": 
"val2"}')), + ToolCall(function=FunctionCall(name="hello_1", arguments='{"arg1": "val1", "arg2": "val2"}')), + ToolCall(function=FunctionCall(name="hello_2_3", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + {"auto": None, "none": 0}, + ), + ( + "emoji_fcall", + [ + ToolCall( + function=FunctionCall(name="he🧦🧦o", arguments='{"arg1": "🐱", "arg2": "🐢", "arg🧦": "🧦"}'), + ) + ], + {"auto": None, "none": 0}, + ), + ( + "pretty_printed_args", + [ + ToolCall( + function=FunctionCall( + name="hello", + arguments='{\n "arg1": "val1",\n "arg2": "val2"\n }\n', + ), + ) + ], + {"auto": None, "none": 0}, + ), + ( + "japanese_fcall", + [ToolCall(function=FunctionCall(name="こんにけは", arguments='{"こん": "にけは"}'))], + {"auto": None, "none": 0}, + ), + ] + for case_name, content, valid_for in items: + tokens = _encode_content(instruct_tokenizer, content) + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode=mode, + ) + ) + + if tokenizer.version < TokenizerVersion.v13: + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "none": 0} + else: + reasoning_valid_for = {"auto": None, "none": 0} + for mode, should_fail_on in reasoning_valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=f"{case_name}_reasoning", + mode=mode, + reasoning=True, + ) + ) + + # Broken / missing args edge cases + token_items: list[tuple[str, list[int], dict[Mode, int | None]]] = [ + ( + "fcall_broken_args", + [ + tokenizer.get_special_token("[TOOL_CALLS]"), + *tokenizer.encode("hello", bos=False, eos=False), + tokenizer.get_special_token("[ARGS]"), + *tokenizer.encode('{"a', bos=False, eos=False), + tokenizer.eos_id, + ], + {"auto": -1, "none": 0}, + ), + ( + "fcall_missing_args", + [ + tokenizer.get_special_token("[TOOL_CALLS]"), + *tokenizer.encode("hello", bos=False, eos=False), + 
tokenizer.get_special_token("[ARGS]"), + tokenizer.get_special_token("[TOOL_CALLS]"), + ], + {"auto": -1, "none": 0}, + ), + ] + + for case_name, tokens, valid_for in token_items: + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode=mode, + ) + ) + return cases + + +def _generate_cases_text_and_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + content = [ + TextChunk(text="Hello!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ] + tokens = _encode_content(instruct_tokenizer, content) + text_len = len(tokenizer.encode("Hello!", bos=False, eos=False)) + + valid_for: dict[Mode, int | None] = {"auto": None, "none": text_len} + + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name="text_fcall", + mode=mode, + ) + ) + + if tokenizer.version < TokenizerVersion.v13: + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "none": 0} + else: + reasoning_valid_for = {"auto": None, "none": text_len} + for mode, should_fail_on in reasoning_valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name="text_fcall_reasoning", + mode=mode, + reasoning=True, + ) + ) + + return cases + + +def _generate_cases_thinking_v11(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + r"""Generate thinking test cases for v11 (plain text think grammar).""" + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + 
cases: list[TestCase] = [] + # Note: "any"/"required" modes require tools, so only "auto"/"none" are used here (tools=None). + thinks: list[tuple[str, list[Any], dict[Mode, int | None]]] = [ + ( + "force_think", + [TextChunk(text="Hello world!")], + {"auto": 0, "none": 0}, + ), + ( + "think_without_response", + [TextChunk(text="Hello!")], + {"auto": -1, "none": -1}, + ), + ( + "unclosed_think", + [TextChunk(text="Hello!")], + {"auto": -1, "none": -1}, + ), + ( + "plain_think_with_response", + [TextChunk(text="Hello!World!")], + { + "auto": None, + "none": None, + }, + ), + ( + "think_with_tool_call", + [ + TextChunk(text="Hello!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + { + "auto": None, + "none": len(tokenizer.encode("Hello!", bos=False, eos=False)), + }, + ), + ( + "think_with_text_and_tool_call", + [ + TextChunk(text="Hello!Ho!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + { + # auto: think (content | fcalls) β€” picks content for "Ho!", then tool call rejected + "auto": len(tokenizer.encode("Hello!Ho!", bos=False, eos=False)), + "none": len(tokenizer.encode("Hello!Ho!", bos=False, eos=False)), + }, + ), + ] + for case_name, content, valid_for in thinks: + tokens = _encode_content(instruct_tokenizer, content) + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode=mode, + reasoning=True, + ) + ) + return cases + + +def _generate_cases_thinking(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + r"""Generate thinking test cases for v13+ (structured think grammar with [THINK]/[/THINK] tokens).""" + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + + def _think_tokens(text: 
str) -> list[int]: + r"""Helper to encode a ThinkChunk.""" + assert isinstance(instruct_tokenizer, InstructTokenizerV13) + return instruct_tokenizer.encode_think(ThinkChunk(thinking=text)) + + # Note: "any"/"required" modes require tools, so only "auto"/"none" are used here (tools=None). + thinks: list[tuple[str, list[Any], dict[Mode, int | None]]] = [ + ( + "plain_text", + [TextChunk(text="Hello world!")], + {"auto": None, "none": None}, + ), + ( + "plain_think", + [ThinkChunk(thinking="Hello!")], + {"auto": -1, "none": -1}, + ), + ( + "plain_think_with_response", + [ThinkChunk(thinking="Hello!"), TextChunk(text="World!")], + {"auto": None, "none": None}, + ), + ( + "think_with_tool_call", + [ + ThinkChunk(thinking="Hello!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + { + "auto": None, + "none": len(_think_tokens("Hello!")), + }, + ), + ( + "think_text_tool_call", + [ + ThinkChunk(thinking="Hello!"), + TextChunk(text="World!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + { + "auto": None, + "none": len(_think_tokens("Hello!")) + len(tokenizer.encode("World!", bos=False, eos=False)), + }, + ), + ] + for case_name, content, valid_for in thinks: + tokens = _encode_content(instruct_tokenizer, content) + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode=mode, + reasoning=True, + ) + ) + return cases + + +def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + single_call = [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}'))] + single_tokens = 
_encode_content(instruct_tokenizer, single_call) + # Only "auto" can be used without tools; "any"/"required" require tools. + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=single_tokens, + should_fail_on=None, + case_name="single_tool_call", + mode="auto", + parallel_tool_calls=False, + ) + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=single_tokens, + should_fail_on=0, + case_name="single_tool_call", + mode="none", + parallel_tool_calls=False, + ) + ) + + if tokenizer.version < TokenizerVersion.v13: + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "none": 0} + else: + reasoning_valid_for = {"auto": None, "none": 0} + for mode, should_fail_on in reasoning_valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=single_tokens, + should_fail_on=should_fail_on, + case_name="single_tool_call_reasoning", + mode=mode, + parallel_tool_calls=False, + reasoning=True, + ) + ) + + multi_calls = [ + ToolCall(function=FunctionCall(name="fn1", arguments='{"arg1": "val1", "arg2": "val2"}')), + ToolCall(function=FunctionCall(name="fn2", arguments='{"arg1": "val1", "arg2": "val2"}')), + ] + multi_tokens = _encode_content(instruct_tokenizer, multi_calls) + single_tokens_with_eos = _encode_content(instruct_tokenizer, [multi_calls[0]]) + fail_idx = len(single_tokens_with_eos) - 1 + + # Only "auto" can be used without tools; "any"/"required" require tools. 
+ cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_tokens, + should_fail_on=fail_idx, + case_name="multi_tool_call_disallowed", + mode="auto", + parallel_tool_calls=False, + ) + ) + + return cases + + +def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + tools_strict = [ToolProvider.retrieve_payment_date(strict=True)] + + # 1. Non-strict tools β€” any function name/args accepted + non_strict_call = [ToolCall(function=FunctionCall(name="fn1", arguments='{"arg1": "val1", "arg2": "val2"}'))] + non_strict_tokens = _encode_content(instruct_tokenizer, non_strict_call) + for mode in _AUTO_ANY_REQUIRED: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=non_strict_tokens, + should_fail_on=None, + case_name="single_non_strict_tool_call", + mode=mode, + tools=[ToolProvider.retrieve_payment_date(strict=False)], + ) + ) + + # 2. Correct strict tool call + strict_call = [ + ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"transaction_id": "12345"}')) + ] + strict_tokens = _encode_content(instruct_tokenizer, strict_call) + for mode in _AUTO_ANY_REQUIRED: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=None, + case_name="single_strict_tool_call", + mode=mode, + tools=[ToolProvider.retrieve_payment_date(strict=True)], + ) + ) + + # 3. Wrong args for strict tool β€” it must fail somewhere before the end. 
+ wrong_args_call = [ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"bogus": "12345"}'))] + wrong_args_tokens = _encode_content(instruct_tokenizer, wrong_args_call) + bogus_start = _find_first_rejection(factory, wrong_args_tokens, mode=ToolChoiceEnum.auto, tools=tools_strict) + for mode in _AUTO_ANY_REQUIRED: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_args_tokens, + should_fail_on=bogus_start, + case_name="strict_tool_call_wrong_args", + mode=mode, + tools=tools_strict, + ) + ) + + # 4. Wrong name for strict tool + wrong_name_call = [ToolCall(function=FunctionCall(name="fn1", arguments='{"transaction_id": "12345"}'))] + wrong_name_tokens = _encode_content(instruct_tokenizer, wrong_name_call) + fail_on_name = _find_first_rejection(factory, wrong_name_tokens, mode=ToolChoiceEnum.auto, tools=tools_strict) + for mode in _AUTO_ANY_REQUIRED: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_name_tokens, + should_fail_on=fail_on_name, + case_name="strict_tool_call_wrong_name", + mode=mode, + tools=tools_strict, + ) + ) + + # 5. Multiple strict tool calls (both correct) + multi_strict = [ + ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"transaction_id": "12345"}')), + ToolCall(function=FunctionCall(name="retrieve_payment_status", arguments='{"transaction_id": "12345"}')), + ] + multi_strict_tokens = _encode_content(instruct_tokenizer, multi_strict) + for mode in _AUTO_ANY_REQUIRED: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_strict_tokens, + should_fail_on=None, + case_name="multiple_strict_tool_calls", + mode=mode, + tools=[ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ], + ) + ) + + # 6. 
reasoning=True variants + if tokenizer.version < TokenizerVersion.v13: + for mode in _AUTO_ANY_REQUIRED: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=0, + case_name="strict_tool_call_reasoning", + mode=mode, + tools=tools_strict, + reasoning=True, + ) + ) + else: + for mode in _AUTO_ANY_REQUIRED: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=None, + case_name="strict_tool_call_reasoning", + mode=mode, + tools=tools_strict, + reasoning=True, + ) + ) + for mode in _AUTO_ANY_REQUIRED: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_strict_tokens, + should_fail_on=None, + case_name="multiple_strict_tool_calls_reasoning", + mode=mode, + tools=[ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ], + reasoning=True, + ) + ) + + return cases + + +def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + r"""Generate test cases for NamedToolChoice (forcing a specific tool).""" + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + + tools = [ + ToolProvider.retrieve_payment_date(strict=False), + ToolProvider.retrieve_payment_status(strict=False), + ] + + # 1. 
NamedToolChoice for retrieve_payment_date with correct function call + named_tool_date = NamedToolChoice( + type=ToolTypes.function, + function=FunctionName(name="retrieve_payment_date"), + ) + correct_date_call = [ + ToolCall( + function=FunctionCall( + name="retrieve_payment_date", + arguments='{"transaction_id": "12345"}', + ) + ) + ] + correct_date_tokens = _encode_content(instruct_tokenizer, correct_date_call) + + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=None, + case_name="named_tool_choice_correct", + mode=named_tool_date, + tools=tools, + parallel_tool_calls=True, + ) + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=None, + case_name="named_tool_choice_correct", + mode=named_tool_date, + tools=tools, + parallel_tool_calls=False, + ) + ) + + # 2. Non-strict NamedToolChoice should NOT enforce JSON arguments schema β€” + arbitrary_args_call = [ + ToolCall( + function=FunctionCall( + name="retrieve_payment_date", + arguments='{"completely": "arbitrary", "keys": [1, 2, 3]}', + ) + ) + ] + arbitrary_args_tokens = _encode_content(instruct_tokenizer, arbitrary_args_call) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=arbitrary_args_tokens, + should_fail_on=None, + case_name="named_tool_choice_non_strict_arbitrary_args", + mode=named_tool_date, + tools=tools, + parallel_tool_calls=True, + ) + ) + + # 3. 
NamedToolChoice should reject a different tool name + wrong_tool_call = [ + ToolCall( + function=FunctionCall( + name="retrieve_payment_status", + arguments='{"transaction_id": "12345"}', + ) + ) + ] + wrong_tool_tokens = _encode_content(instruct_tokenizer, wrong_tool_call) + + fail_idx = _find_first_rejection( + factory, + wrong_tool_tokens, + mode=named_tool_date, + tools=tools, + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_tool_tokens, + should_fail_on=fail_idx, + case_name="named_tool_choice_wrong_tool", + mode=named_tool_date, + tools=tools, + parallel_tool_calls=True, + ) + ) + + # 4. NamedToolChoice with reasoning mode + if tokenizer.version < TokenizerVersion.v13: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=0, + case_name="named_tool_choice_reasoning", + mode=named_tool_date, + tools=tools, + reasoning=True, + ) + ) + else: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=None, + case_name="named_tool_choice_reasoning", + mode=named_tool_date, + tools=tools, + reasoning=True, + ) + ) + + # 5. 
NamedToolChoice with strict tool should validate arguments + strict_tools = [ToolProvider.retrieve_payment_date(strict=True)] + named_tool_strict = NamedToolChoice( + type=ToolTypes.function, + function=FunctionName(name="retrieve_payment_date"), + ) + + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=None, + case_name="named_tool_choice_strict_correct", + mode=named_tool_strict, + tools=strict_tools, + ) + ) + + wrong_args_call = [ + ToolCall( + function=FunctionCall( + name="retrieve_payment_date", + arguments='{"wrong_arg": "12345"}', + ) + ) + ] + wrong_args_tokens = _encode_content(instruct_tokenizer, wrong_args_call) + fail_on_args = _find_first_rejection( + factory, + wrong_args_tokens, + mode=named_tool_strict, + tools=strict_tools, + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_args_tokens, + should_fail_on=fail_on_args, + case_name="named_tool_choice_strict_wrong_args", + mode=named_tool_strict, + tools=strict_tools, + ) + ) + + return cases + + +def _find_first_json_schema_rejection( + factory: GrammarFactory, + tokens: list[int], + json_schema: dict[str, Any], +) -> int: + r"""Finds the index of the first token rejected by the JSON schema grammar. + + Args: + factory: The grammar factory. + tokens: The token sequence to test. + json_schema: The JSON schema to validate against. + + Returns: + The index of the first rejected token. + + Raises: + ValueError: If all tokens are accepted. 
+ """ + template = factory.select_jinja_template(reasoning=False) + grammar = factory.get_lark_for_json_schema(template=template, json_schema=json_schema) + matcher = llg.LLMatcher(factory.llg_tokenizer, grammar) + for i, token in enumerate(tokens): + if not matcher.consume_token(token): + return i + raise ValueError("All tokens were accepted β€” expected a rejection") + + +def _generate_json_schema(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + items: list[tuple[str, str, int | None, dict[str, Any]]] = [ + ( + "basic_person_valid", + '{"name": "John", "age": 30}', + None, + SchemaProvider.basic_person(), + ), + ( + "invalid_json_missing_curly_bracket", + '"name": "John", "age": 30}', + 0, + {"type": "object"}, + ), + ( + "valid_json_white_spaces", + '\n {"name": "John", "age": 30}', + None, + {"type": "object"}, + ), + ( + "invalid_json_backslash_f", + '\f{"name": "John", "age": 30}', + 0, + {"type": "object"}, + ), + ( + "basic_person_non_strict_valid", + '{"age": "John", "name": 30}', + None, + {"type": "object"}, + ), + ( + "domerge_valid", + '{"new_clusters": {"b": ["a", "b", "c"], "d": ["e"]} }', + None, + SchemaProvider.basic_dict_of_list(), + ), + ] + + for case_name, text, should_fail_on, json_schema in items: + tokens = _encode_content(instruct_tokenizer, text) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode="auto", + json_schema=json_schema, + ) + ) + + # Cases where should_fail_on must be computed dynamically because the exact + # rejection index depends on the grammar engine's internal byte-level parsing. 
+ person_invalid_tokens = _encode_content(instruct_tokenizer, '{"age": "John", "name": 30}') + person_schema = SchemaProvider.basic_person() + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=person_invalid_tokens, + should_fail_on=_find_first_json_schema_rejection(factory, person_invalid_tokens, person_schema), + case_name="basic_person_invalid", + mode="auto", + json_schema=person_schema, + ) + ) + + return cases + + +def _generate_json_schema_reasoning(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + + # Valid JSON should be accepted even under reasoning templates with json_only + items: list[tuple[str, str, int | None, dict[str, Any]]] = [ + ( + "json_schema_reasoning_valid", + '{"name": "John", "age": 30}', + None, + SchemaProvider.basic_person(), + ), + ( + "json_schema_reasoning_whitespace_valid", + '\n {"name": "John", "age": 30}', + None, + {"type": "object"}, + ), + ( + "json_schema_reasoning_invalid_bracket", + '"name": "John", "age": 30}', + 0, + {"type": "object"}, + ), + ] + for case_name, text, should_fail_on, json_schema in items: + tokens = _encode_content(instruct_tokenizer, text) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode="auto", + json_schema=json_schema, + reasoning=True, + ) + ) + + return cases + + +def _generate_json_only_negative(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + + # Plain text should be rejected by json_only grammar + text_tokens = _encode_content(instruct_tokenizer, "Hello 
world!") + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=text_tokens, + should_fail_on=0, + case_name="json_only_rejects_text", + mode="auto", + json_schema={"type": "object"}, + reasoning=False, + ) + ) + + # Tool calls should be rejected by json_only grammar + tool_call_tokens = _encode_content( + instruct_tokenizer, + [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1"}'))], + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tool_call_tokens, + should_fail_on=0, + case_name="json_only_rejects_tool_call", + mode="auto", + json_schema={"type": "object"}, + reasoning=False, + ) + ) + + # Plain text should also be rejected with reasoning=True + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=text_tokens, + should_fail_on=0, + case_name="json_only_reasoning_rejects_text", + mode="auto", + json_schema={"type": "object"}, + reasoning=True, + ) + ) + + return cases + + +def _generate_json_schema_think_with_json( + mistral_tokenizer: MistralTokenizer, factory: GrammarFactory +) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + assert isinstance(instruct_tokenizer, InstructTokenizerV13) + + cases: list[TestCase] = [] + + json_schema = SchemaProvider.basic_person() + valid_json = '{"name": "John", "age": 30}' + + def _think_tokens(text: str) -> list[int]: + return instruct_tokenizer.encode_think(ThinkChunk(thinking=text)) + + # Think + valid JSON should be accepted with json_only + reasoning + think_json_tokens = [ + *_think_tokens("Let me think about this..."), + *tokenizer.encode(valid_json, bos=False, eos=False), + tokenizer.eos_id, + ] + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=think_json_tokens, + should_fail_on=None, + case_name="think_with_json_valid", + mode="auto", + json_schema=json_schema, + reasoning=True, + ) + ) + + json_only_tokens = 
_encode_content(instruct_tokenizer, valid_json) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=json_only_tokens, + should_fail_on=None, + case_name="think_with_json_no_think_valid", + mode="auto", + json_schema=json_schema, + reasoning=True, + ) + ) + + think_text_tokens = [ + *_think_tokens("Let me think..."), + *tokenizer.encode("Hello world!", bos=False, eos=False), + tokenizer.eos_id, + ] + + fail_idx = len(_think_tokens("Let me think...")) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=think_text_tokens, + should_fail_on=fail_idx, + case_name="think_with_json_rejects_text_after_think", + mode="auto", + json_schema=json_schema, + reasoning=True, + ) + ) + + return cases + + +def _generate_cases(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + tokenizer_version = instruct_tokenizer.tokenizer.version + + cases = _generate_general_cases(mistral_tokenizer) + cases += _generate_emoji_cases(mistral_tokenizer) + cases += _generate_json_schema(mistral_tokenizer, factory) + cases += _generate_json_schema_reasoning(mistral_tokenizer, factory) + cases += _generate_json_only_negative(mistral_tokenizer, factory) + cases += _generate_cases_tool_calls(mistral_tokenizer) + cases += _generate_single_tool_call(mistral_tokenizer) + cases += _generate_strict_tool_calls(mistral_tokenizer, factory) + cases += _generate_named_tool_choice(mistral_tokenizer, factory) + cases += _generate_cases_text_and_tool_calls(mistral_tokenizer) + + if tokenizer_version >= TokenizerVersion.v13: + cases += _generate_cases_thinking(mistral_tokenizer) + else: + cases += _generate_cases_thinking_v11(mistral_tokenizer) + + # v15+ supports think_with_json (thinking before JSON in json_only mode) + if tokenizer_version.supports_model_settings: + cases += _generate_json_schema_think_with_json(mistral_tokenizer, factory) + + 
return cases + + +_grammar_factories: dict[int, GrammarFactory] = {} + + +def _get_grammar_factory(mistral_tokenizer: MistralTokenizer) -> GrammarFactory: + tok_id = id(mistral_tokenizer) + if tok_id not in _grammar_factories: + _grammar_factories[tok_id] = GrammarFactory(mistral_tokenizer) + return _grammar_factories[tok_id] + + +_ALL_TOKENIZERS: list[MistralTokenizer] = [ + _build_tekken_mistral_tokenizer(TokenizerVersion.v11), + _build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True), + _build_tekken_mistral_tokenizer( + TokenizerVersion.v15, add_think=True, model_settings_builder=_V15_MODEL_SETTINGS_BUILDER + ), +] + +_ALL_CASES: list[TestCase] = [] +_ALL_MISTRAL_TOKENIZERS: dict[int, MistralTokenizer] = {} +for mistral_tokenizer in _ALL_TOKENIZERS: + tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer + _ALL_MISTRAL_TOKENIZERS[id(tokenizer)] = mistral_tokenizer + factory = _get_grammar_factory(mistral_tokenizer) + _ALL_CASES.extend(_generate_cases(mistral_tokenizer, factory)) + + +class TestGrammarFactory: + @pytest.mark.parametrize("test_case", _ALL_CASES, ids=lambda tc: tc.name) + def test_grammar(self, test_case: TestCase) -> None: + mistral_tokenizer = _ALL_MISTRAL_TOKENIZERS[id(test_case.tokenizer)] + factory = _get_grammar_factory(mistral_tokenizer) + + if test_case.raw_lark is not None: + grammar = test_case.raw_lark + elif test_case.json_schema is not None and test_case.tools is None: + template = factory.select_jinja_template(reasoning=test_case.reasoning) + grammar = factory.get_lark_for_json_schema(template=template, json_schema=test_case.json_schema) + else: + template = factory.select_jinja_template(reasoning=test_case.reasoning) + resolved_mode: ToolChoice + if isinstance(test_case.mode, NamedToolChoice): + resolved_mode = test_case.mode + else: + resolved_mode = ToolChoiceEnum(test_case.mode) + grammar = factory.get_lark_from_jinja( + template=template, + mode=resolved_mode, + tools=test_case.tools, + 
json_schema=test_case.json_schema, + parallel_tool_calls=test_case.parallel_tool_calls, + ) + + matcher = llg.LLMatcher(factory.llg_tokenizer, grammar) + + assert is_tekkenizer(test_case.tokenizer) + debug_tokens = [_token_debug_repr(test_case.tokenizer, t) for t in test_case.tokens] + for i, token in enumerate(test_case.tokens): + debug_bytes = _token_debug_repr(test_case.tokenizer, token) + if token != test_case.tokenizer.eos_id: + assert not matcher.is_stopped(), ( + f"Matcher stopped before token {i} id={token}\n\n" + f"Grammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" + ) + accepted = matcher.consume_token(token) + if i == test_case.should_fail_on: + if accepted: + raise AssertionError( + f"Token {token}={debug_bytes} at pos {i} was accepted but should have been rejected." + f"\n\nGrammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" + ) + break + elif not accepted: + raise AssertionError( + f"Token {token}={debug_bytes} at pos {i} was rejected but should have been accepted." + f"\n\nGrammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" + ) + + # For fully accepted sequences, verify the matcher reached a valid terminal state. + # Raw lark grammars (e.g., emoji matcher) don't consume EOS, so they may not reach + # a stopped state β€” skip the check for those. 
+ if test_case.should_fail_on is None and test_case.raw_lark is None: + assert matcher.is_stopped(), ( + f"Matcher did not reach terminal state after consuming all tokens.\n\n" + f"Grammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" + ) + + @pytest.mark.parametrize( + ("tokenizer", "expected"), + [ + (MistralTokenizer.v1(), False), + (MistralTokenizer.v3(is_tekken=True), False), + (_build_tekken_mistral_tokenizer(TokenizerVersion.v11), True), + (_build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True), True), + ( + _build_tekken_mistral_tokenizer( + TokenizerVersion.v15, add_think=True, model_settings_builder=_V15_MODEL_SETTINGS_BUILDER + ), + True, + ), + ], + ) + def test_grammar_factory_is_supported(self, tokenizer: MistralTokenizer, expected: bool) -> None: + assert GrammarFactory.is_supported(tokenizer) is expected + + @pytest.mark.parametrize( + "tokenizer", + [MistralTokenizer.v1(), MistralTokenizer.v3(is_tekken=True)], + ) + def test_grammar_factory_init_rejects_unsupported(self, tokenizer: MistralTokenizer) -> None: + with pytest.raises(ValueError, match="Guidance requires a Tekken tokenizer with version >= v11"): + GrammarFactory(tokenizer) + + @pytest.mark.parametrize("mode", [ToolChoiceEnum.any, ToolChoiceEnum.required]) + def test_get_lark_rejects_any_required_without_tools( + self, v11_tekken: MistralTokenizer, mode: ToolChoiceEnum + ) -> None: + factory = GrammarFactory(v11_tekken) + template = factory.select_jinja_template(reasoning=False) + with pytest.raises(ValueError, match="please ensure to pass tools"): + factory.get_lark_from_jinja( + template=template, mode=mode, tools=None, json_schema=None, parallel_tool_calls=True + ) + + def test_get_lark_rejects_named_tool_not_in_tools(self, v11_tekken: MistralTokenizer) -> None: + factory = GrammarFactory(v11_tekken) + template = factory.select_jinja_template(reasoning=False) + named = NamedToolChoice(function=FunctionName(name="non_existent")) + tools = 
[ToolProvider.retrieve_payment_date(strict=True)] + with pytest.raises(ValueError, match="no tools with this name"): + factory.get_lark_from_jinja( + template=template, mode=named, tools=tools, json_schema=None, parallel_tool_calls=True + ) + + @pytest.mark.parametrize("mode", [ToolChoiceEnum.any, ToolChoiceEnum.required]) + @pytest.mark.parametrize("tools", [None, []]) + def test_any_required_without_tools_raises( + self, v11_tekken: MistralTokenizer, mode: ToolChoiceEnum, tools: list[Tool] | None + ) -> None: + factory = GrammarFactory(v11_tekken) + template = factory.select_jinja_template(reasoning=False) + with pytest.raises(ValueError, match="please ensure to pass tools"): + factory.get_lark_from_jinja( + template=template, tools=tools, mode=mode, json_schema=None, parallel_tool_calls=True + ) + + @pytest.mark.parametrize("tools", [None, [], [Tool(function=Function(name="existing_tool", parameters={}))]]) + def test_named_wrong_tools_raises(self, v11_tekken: MistralTokenizer, tools: list[Tool] | None) -> None: + factory = GrammarFactory(v11_tekken) + template = factory.select_jinja_template(reasoning=False) + with pytest.raises(ValueError, match="no tools with this name"): + factory.get_lark_from_jinja( + template=template, + tools=tools, + mode=NamedToolChoice(function=FunctionName(name="non_existent_tool")), + json_schema=None, + parallel_tool_calls=True, + ) + + +def _stub_get_special_token_id(token_name: str) -> str: + r"""Returns a stub lark grammar token for testing.""" + return f"<[{token_name}]>" + + +class TestConvertToolCalls: + def test_none_mode(self) -> None: + result = _convert_tool_calls( + tools=None, + mode=ToolChoiceEnum.none, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert result == "" + + def test_none_mode_with_tools(self) -> None: + tools = [ToolProvider.retrieve_payment_date(strict=True)] + result = _convert_tool_calls( + tools=tools, + mode=ToolChoiceEnum.none, + parallel_tool_calls=True, + 
get_special_token_id=_stub_get_special_token_id, + ) + assert result == "" + + def test_auto_mode_no_tools(self) -> None: + result = _convert_tool_calls( + tools=None, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert "<[[TOOL_CALLS]]>" in result + assert "<[[ARGS]]>" in result + assert "/.+/" in result + assert not result.startswith("(") + assert not result.endswith(")+") + + def test_auto_mode_non_strict(self) -> None: + tools = [ToolProvider.retrieve_payment_date(strict=False)] + result = _convert_tool_calls( + tools=tools, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert "<[[TOOL_CALLS]]>" in result + assert "/.+/" in result + assert not result.startswith("(") + assert not result.endswith(")+") + + def test_auto_mode_strict(self) -> None: + tools = [ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ] + result = _convert_tool_calls( + tools=tools, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert '"retrieve_payment_date"' in result + assert '"retrieve_payment_status"' in result + assert " | " in result + assert result.count("(") >= 2 + assert not result.endswith(")+") + + def test_auto_mode_single_strict(self) -> None: + tools = [ToolProvider.retrieve_payment_date(strict=True)] + result = _convert_tool_calls( + tools=tools, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert '"retrieve_payment_date"' in result + assert " | " not in result + assert not result.startswith("(") + assert not result.endswith(")+") + + def test_named_tool_choice_non_strict(self) -> None: + named = NamedToolChoice(function=FunctionName(name="retrieve_payment_date")) + tools = [ToolProvider.retrieve_payment_date(strict=False)] + result = 
_convert_tool_calls( + tools=tools, + mode=named, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert '"retrieve_payment_date"' in result + assert "/.+/" not in result + assert not result.endswith(")+") + + def test_named_tool_choice_strict(self) -> None: + named = NamedToolChoice(function=FunctionName(name="retrieve_payment_date")) + tools = [ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ] + result = _convert_tool_calls( + tools=tools, + mode=named, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert '"retrieve_payment_date"' in result + assert '"retrieve_payment_status"' not in result + assert " | " not in result + assert not result.startswith("(") + assert not result.endswith(")+") + + def test_parallel_tool_calls(self) -> None: + result = _convert_tool_calls( + tools=None, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=True, + get_special_token_id=_stub_get_special_token_id, + ) + assert result.startswith("(") and result.endswith(")+") + + def test_empty_params_strict_tool(self) -> None: + tool = Tool(function=Function(name="empty_fn", parameters={}, strict=True)) + result = _convert_tool_calls( + tools=[tool], + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert '"additionalProperties": false' in result + assert '"properties": {}' in result + assert not result.startswith("(") + assert not result.endswith(")+") diff --git a/tests/guidance/test_tokenizer.py b/tests/guidance/test_tokenizer.py new file mode 100644 index 00000000..491444ec --- /dev/null +++ b/tests/guidance/test_tokenizer.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import re +from unittest.mock import MagicMock + +import llguidance as llg +import pytest + +from mistral_common.guidance.tokenizer import MistralLLGTokenizer, from_mistral_tokenizer +from 
mistral_common.protocol.instruct.normalize import get_normalizer +from mistral_common.protocol.instruct.validator import ValidationMode, get_validator +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens, Tokenizer, TokenizerVersion +from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV7 +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer +from tests.test_tekken import get_special_tokens, quick_vocab + +_NUM_SPECIAL_TOKENS = 100 +_EXTRA_TOKENS = [b"a", b"b", b"c", b"f", b"de", b"he", b"llo"] + + +@pytest.fixture(scope="module") +def tekkenizer() -> Tekkenizer: + special_tokens = get_special_tokens(TokenizerVersion.v7) + return Tekkenizer( + quick_vocab(_EXTRA_TOKENS), + special_tokens=special_tokens, + pattern=r".+", + vocab_size=256 + _NUM_SPECIAL_TOKENS, + num_special_tokens=_NUM_SPECIAL_TOKENS, + version=TokenizerVersion.v7, + ) + + +@pytest.fixture(scope="module") +def llg_tokenizer(tekkenizer: Tekkenizer) -> MistralLLGTokenizer: + return MistralLLGTokenizer(tekkenizer) + + +@pytest.fixture(scope="module") +def mistral_tokenizer(tekkenizer: Tekkenizer) -> MistralTokenizer: + instruct_tokenizer = InstructTokenizerV7(tekkenizer) + normalizer = get_normalizer(tekkenizer.version, tekkenizer.model_settings_builder) + validator = get_validator(tekkenizer.version, mode=ValidationMode.test) + return MistralTokenizer(instruct_tokenizer, validator=validator, request_normalizer=normalizer) + + +@pytest.fixture(scope="module") +def ll_tokenizer(mistral_tokenizer: MistralTokenizer) -> llg.LLTokenizer: + return from_mistral_tokenizer(mistral_tokenizer) + + +class TestMistralLLGTokenizer: + def test_init_rejects_non_tekkenizer(self) -> None: + mock_tokenizer = MagicMock(spec=Tokenizer) + with pytest.raises(TypeError, match="Guidance only supports Tekken tokenizers"): + MistralLLGTokenizer(mock_tokenizer) + + def 
test_eos_and_bos_ids(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None:
+        assert llg_tokenizer.eos_token_id == tekkenizer.eos_id
+        assert llg_tokenizer.bos_token_id == tekkenizer.bos_id
+
+    def test_tokens_length(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None:
+        assert len(llg_tokenizer.tokens) == tekkenizer.n_words
+
+    def test_special_token_ids_count(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None:
+        assert len(llg_tokenizer.special_token_ids) == tekkenizer.num_special_tokens
+
+    def test_all_special_tokens_are_angle_bracketed(self, llg_tokenizer: MistralLLGTokenizer) -> None:
+        for i in llg_tokenizer.special_token_ids:
+            token_bytes = llg_tokenizer.tokens[i]
+            token_str = token_bytes.decode("utf-8")
+            assert re.fullmatch(r"<.*>", token_str), f"Special token at id={i} is not angle-bracketed: {token_str!r}"
+
+    def test_special_token_conversion(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None:
+        checked = 0
+        for token in SpecialTokens:
+            try:
+                rank = tekkenizer.get_special_token(token.value)
+            except ValueError:
+                continue
+
+            llg_bytes = llg_tokenizer.tokens[rank]
+            if re.fullmatch(r"\[.*\]", token.value):
+                expected = token.value.replace("[", "<").replace("]", ">")
+            else:
+                expected = token.value
+
+            assert llg_bytes == expected.encode("utf-8"), (
+                f"Token {token.name} at rank={rank}: expected {expected!r}, got {llg_bytes!r}"
+            )
+            checked += 1
+
+        # Filler tokens (<SPECIAL_\d+>) should be preserved as-is
+        for i in range(tekkenizer.num_special_tokens):
+            piece = tekkenizer.id_to_piece(i)
+            if re.fullmatch(r"<SPECIAL_\d+>", piece):
+                assert llg_tokenizer.tokens[i] == piece.encode("utf-8"), (
+                    f"Filler token at id={i}: expected {piece!r}, got {llg_tokenizer.tokens[i]!r}"
+                )
+                checked += 1
+
+        assert checked == tekkenizer.num_special_tokens
+
+    def test_non_special_tokens_match_byte_pieces(
+        self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer
+    ) -> None:
+        for i in 
range(tekkenizer.num_special_tokens, tekkenizer.n_words):
+            expected = tekkenizer.id_to_byte_piece(i, SpecialTokenPolicy.RAISE)
+            assert llg_tokenizer.tokens[i] == expected, (
+                f"Token at id={i}: expected {expected!r}, got {llg_tokenizer.tokens[i]!r}"
+            )
+
+    def test_call_encodes_string(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None:
+        test_string = "abc"
+        assert llg_tokenizer(test_string) == tekkenizer.encode(test_string, bos=False, eos=False)
+
+    def test_init_rejects_invalid_special_token_format(self) -> None:
+        base = list(Tekkenizer.DEPRECATED_SPECIAL_TOKENS)
+        next_rank = len(base)
+        special_tokens: list[SpecialTokenInfo] = [
+            *base,
+            SpecialTokenInfo(rank=next_rank, token_str="INVALID_NO_BRACKETS", is_control=True),
+        ]
+        vocab = quick_vocab()
+        num_special = len(special_tokens)
+        tekkenizer = Tekkenizer(
+            vocab,
+            special_tokens=special_tokens,
+            pattern=r".+",
+            vocab_size=len(vocab) + num_special,
+            num_special_tokens=num_special,
+            version=TokenizerVersion.v7,
+        )
+        with pytest.raises(ValueError, match="Invalid special token"):
+            MistralLLGTokenizer(tekkenizer)
+
+    def test_init_rejects_duplicate_special_tokens(self) -> None:
+        base = list(Tekkenizer.DEPRECATED_SPECIAL_TOKENS)
+        next_rank = len(base)
+        # Both [CUSTOM] and <CUSTOM> map to <CUSTOM> after bracket conversion
+        special_tokens: list[SpecialTokenInfo] = [
+            *base,
+            SpecialTokenInfo(rank=next_rank, token_str="<CUSTOM>", is_control=True),
+            SpecialTokenInfo(rank=next_rank + 1, token_str="[CUSTOM]", is_control=True),
+        ]
+        vocab = quick_vocab()
+        num_special = len(special_tokens)
+        tekkenizer = Tekkenizer(
+            vocab,
+            special_tokens=special_tokens,
+            pattern=r".+",
+            vocab_size=len(vocab) + num_special,
+            num_special_tokens=num_special,
+            version=TokenizerVersion.v7,
+        )
+        with pytest.raises(ValueError, match="Duplicate special token"):
+            MistralLLGTokenizer(tekkenizer)
+
+
+class TestFromMistralTokenizer:
+    def test_properties(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> 
None: + assert isinstance(ll_tokenizer, llg.LLTokenizer) + assert ll_tokenizer.vocab_size == tekkenizer.n_words + assert ll_tokenizer.eos_token == tekkenizer.eos_id + for i in range(tekkenizer.num_special_tokens): + assert ll_tokenizer.is_special_token(i), f"Token id={i} should be special" + + def test_tokenize_str_matches_tekkenizer(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None: + test_strings = ["abc", "hello", "de", "abcdefhello"] + for s in test_strings: + expected = tekkenizer.encode(s, bos=False, eos=False) + result = ll_tokenizer.tokenize_str(s) + assert result == expected, f"Mismatch for {s!r}: expected {expected}, got {result}" + + def test_decode_str_roundtrip(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None: + test_strings = ["abc", "hello", "de"] + for s in test_strings: + tokens = tekkenizer.encode(s, bos=False, eos=False) + decoded = ll_tokenizer.decode_str(tokens) + assert decoded == s, f"Roundtrip failed for {s!r}: got {decoded!r}" + + def test_decode_bytes_roundtrip(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None: + test_strings = ["abc", "hello"] + for s in test_strings: + tokens = tekkenizer.encode(s, bos=False, eos=False) + decoded = ll_tokenizer.decode_bytes(tokens) + assert decoded == s.encode("utf-8"), f"Roundtrip failed for {s!r}: got {decoded!r}" + + def test_decode_special_tokens(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None: + for token in SpecialTokens: + try: + rank = tekkenizer.get_special_token(token.value) + except ValueError: + continue + + if re.fullmatch(r"\[.*\]", token.value): + expected = token.value.replace("[", "<").replace("]", ">") + else: + expected = token.value + + decoded_str = ll_tokenizer.decode_str([rank]) + assert decoded_str == expected, ( + f"decode_str for {token.name} (id={rank}): expected {expected!r}, got {decoded_str!r}" + ) + decoded_bytes = ll_tokenizer.decode_bytes([rank]) + assert decoded_bytes == 
expected.encode("utf-8"), ( + f"decode_bytes for {token.name} (id={rank}): " + f"expected {expected.encode('utf-8')!r}, got {decoded_bytes!r}" + ) diff --git a/tests/test_imports.py b/tests/test_imports.py index 4ba25dad..df1de977 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,19 +1,23 @@ import builtins from functools import _lru_cache_wrapper from types import ModuleType -from typing import Any, Callable +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest from mistral_common.imports import ( assert_hf_hub_installed, + assert_jinja2_installed, + assert_llguidance_installed, assert_opencv_installed, assert_package_installed, assert_sentencepiece_installed, assert_soundfile_installed, assert_soxr_installed, is_hf_hub_installed, + is_jinja2_installed, + is_llguidance_installed, is_opencv_installed, is_package_installed, is_sentencepiece_installed, @@ -23,6 +27,8 @@ _IS_INSTALLED_TO_TESTS = [ is_hf_hub_installed, + is_jinja2_installed, + is_llguidance_installed, is_sentencepiece_installed, is_soundfile_installed, is_soxr_installed, @@ -34,6 +40,16 @@ assert_hf_hub_installed, "`huggingface_hub` is not installed. Please install it with `pip install mistral-common[hf-hub]`", ), + ( + is_jinja2_installed, + assert_jinja2_installed, + "`jinja2` is not installed. Please install it with `pip install mistral-common[guidance]`", + ), + ( + is_llguidance_installed, + assert_llguidance_installed, + "`llguidance` is not installed. 
Please install it with `pip install mistral-common[guidance]`", + ), ( is_opencv_installed, assert_opencv_installed, @@ -117,16 +133,19 @@ def test_is_installed(mock_is_package_installed: MagicMock, is_installed_fn: _lr def test_assert_installed( mock_is_package_installed: MagicMock, is_installed_fn: _lru_cache_wrapper, - assert_fn: Callable[[], None], + assert_fn: _lru_cache_wrapper, error_message: str, ) -> None: is_installed_fn.cache_clear() + assert_fn.cache_clear() mock_is_package_installed.return_value = True assert_fn() is_installed_fn.cache_clear() + assert_fn.cache_clear() mock_is_package_installed.return_value = False with pytest.raises(ImportError) as exc_info: assert_fn() assert str(exc_info.value) == error_message is_installed_fn.cache_clear() + assert_fn.cache_clear() diff --git a/uv.lock b/uv.lock index ef30b985..8b84d3f5 100644 --- a/uv.lock +++ b/uv.lock @@ -759,6 +759,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/0e/b27cdbaccf30b890c40ed1da9fd4a3593a5cf94dae54fb34f8a4b74fcd3f/jsonschema_specifications-2025.4.1-py3-none-any.whl", hash = "sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af", size = 18437, upload-time = "2025-04-23T12:34:05.422Z" }, ] +[[package]] +name = "llguidance" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/48/3f7a9d3ff1b36bba92b5107a3a21286821227afe9ea464736133994d61fb/llguidance-1.3.0.tar.gz", hash = "sha256:861249afd51dc325646834462ea827e57a5c2b2042e108e6aae7059fdad9104d", size = 1070460, upload-time = "2025-10-20T19:58:44.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/33/be5acb85cd8cdc4afde33d9c234eece9f318e087920255af3c05864cd3e7/llguidance-1.3.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f7685222660a762e481ac633d49cc559c64980fe2ee59c8f932a5bb5cbc0c2c2", size = 3220647, upload-time = "2025-10-20T19:58:42.542Z" }, + { url = 
"https://files.pythonhosted.org/packages/82/e6/b48bda5b15efeaeb62bd0dba8fc6a01d4ae5457a85dbb5d18632385fe15c/llguidance-1.3.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:098030ff0687261a3f1bd54cf21fe951fc861d56d37a0671250dd36677eaf224", size = 3099830, upload-time = "2025-10-20T19:58:40.826Z" }, + { url = "https://files.pythonhosted.org/packages/aa/11/44389d3d1526d7a5c38ffd587a5ebc61d7bee443ac1dea95f2089ad58f5f/llguidance-1.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f6caca5d78db7f76e1fbb0fff8607b861c32d47fa3d5dee2fc49de27ee269df", size = 2835242, upload-time = "2025-10-20T19:58:34.518Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ca/53ea256396405e4dee70d5a4a35e18543408e18bb16b251d6ca6b5d80310/llguidance-1.3.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0612bb3f034d2487b6e8f9561f02a94a6039d88273bf0c5c539a3bd3895e47d2", size = 3297480, upload-time = "2025-10-20T19:58:37.033Z" }, + { url = "https://files.pythonhosted.org/packages/83/a8/1ff2bedb8f9acb46a2d2d603415d272bb622c142ea86f5b95445cc6e366c/llguidance-1.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc17e9dd602c3879bf91664a64bf72f54c74dbfbeb24ccfab6a5fe435b12f7aa", size = 3033133, upload-time = "2025-10-20T19:58:38.721Z" }, + { url = "https://files.pythonhosted.org/packages/d7/a7/9b8086c0cfdddf3f6d47b173a404fa7ac46272f7affbee082c36740f4f1c/llguidance-1.3.0-cp39-abi3-win32.whl", hash = "sha256:2f6f558485a43e273fc5c6c974a9a3ace5d5e170076db9b40e0560e41c3ff18f", size = 2598109, upload-time = "2025-10-20T19:58:47.656Z" }, + { url = "https://files.pythonhosted.org/packages/5a/7e/809349638231f469b9056c0e1bfd924d5ef5558b3b3ec72d093b6fad33b1/llguidance-1.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:1d1cd1c8618d1a13605d3e057c978651e551c8c469b481ee4041f1d6c436002d", size = 2789946, upload-time = "2025-10-20T19:58:45.958Z" }, +] + [[package]] name = "markdown" version = "3.8.2" @@ -874,10 +889,27 @@ dependencies = 
[ ] [package.optional-dependencies] +all = [ + { name = "click" }, + { name = "fastapi", extra = ["standard"] }, + { name = "huggingface-hub" }, + { name = "jinja2" }, + { name = "llguidance" }, + { name = "opencv-python-headless" }, + { name = "pydantic-settings" }, + { name = "sentencepiece" }, + { name = "soundfile" }, + { name = "soxr" }, + { name = "uvloop", version = "0.22.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, +] audio = [ { name = "soundfile" }, { name = "soxr" }, ] +guidance = [ + { name = "jinja2" }, + { name = "llguidance" }, +] hf-hub = [ { name = "huggingface-hub" }, ] @@ -934,8 +966,11 @@ requires-dist = [ { name = "click", marker = "extra == 'server'", specifier = ">=8.1.0" }, { name = "fastapi", extras = ["standard"], marker = "extra == 'server'", specifier = ">=0.118.3" }, { name = "huggingface-hub", marker = "extra == 'hf-hub'", specifier = ">=1.0" }, + { name = "jinja2", marker = "extra == 'guidance'", specifier = ">=3.1.0" }, { name = "jsonschema", specifier = ">=4.21.1" }, + { name = "llguidance", marker = "extra == 'guidance'", specifier = ">=1.3.0,<1.4.0" }, { name = "mistral-common", extras = ["opencv"], marker = "extra == 'image'" }, + { name = "mistral-common", extras = ["opencv", "sentencepiece", "audio", "image", "guidance", "hf-hub", "server"], marker = "extra == 'all'" }, { name = "mistral-common", extras = ["soundfile"], marker = "extra == 'audio'" }, { name = "mistral-common", extras = ["soxr"], marker = "extra == 'audio'" }, { name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.25,<2.4" }, @@ -955,7 +990,7 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.11.0" }, { name = "uvloop", marker = "python_full_version >= '3.14' and extra == 'server'", specifier = ">=0.22.1" }, ] -provides-extras = ["opencv", "sentencepiece", "soundfile", "soxr", "audio", "image", "hf-hub", "server"] +provides-extras = ["opencv", "sentencepiece", 
"soundfile", "soxr", "audio", "image", "guidance", "hf-hub", "server", "all"] [package.metadata.requires-dev] dev = [