From 497c4c7ff56e357398cf2956a5188c9e5b373de7 Mon Sep 17 00:00:00 2001 From: Julien Denize Date: Fri, 20 Mar 2026 12:10:14 +0100 Subject: [PATCH 01/21] Add Mistral guidance --- pyproject.toml | 7 +- src/mistral_common/guidance/__init__.py | 0 .../guidance/data/base_grammar.lark.jinja | 19 + .../data/plain_text_think_grammar.lark.jinja | 25 + .../guidance/data/think_grammar.lark.jinja | 21 + .../guidance/grammar_factory.py | 186 +++ src/mistral_common/guidance/tokenizer.py | 104 ++ src/mistral_common/imports.py | 25 + .../tokens/tokenizers/sentencepiece.py | 6 + .../tokens/tokenizers/tekken.py | 7 +- tests/data/emoji.lark | 2 + tests/guidance/__init__.py | 0 tests/guidance/test_guidance.py | 1021 +++++++++++++++++ tests/guidance/test_tokenizer.py | 168 +++ tests/test_imports.py | 23 +- uv.lock | 37 +- 16 files changed, 1645 insertions(+), 6 deletions(-) create mode 100644 src/mistral_common/guidance/__init__.py create mode 100644 src/mistral_common/guidance/data/base_grammar.lark.jinja create mode 100644 src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja create mode 100644 src/mistral_common/guidance/data/think_grammar.lark.jinja create mode 100644 src/mistral_common/guidance/grammar_factory.py create mode 100644 src/mistral_common/guidance/tokenizer.py create mode 100644 tests/data/emoji.lark create mode 100644 tests/guidance/__init__.py create mode 100644 tests/guidance/test_guidance.py create mode 100644 tests/guidance/test_tokenizer.py diff --git a/pyproject.toml b/pyproject.toml index 9092b594..75176bc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ soundfile = ["soundfile>=0.12.1"] soxr = ["soxr>=0.5.0"] audio = ["mistral_common[soundfile]", "mistral_common[soxr]"] image = ["mistral_common[opencv]"] - +guidance = ["llguidance>=1.3.0,<1.4.0", "jinja2>=3.1.0"] hf-hub = ["huggingface-hub>=1.0"] server = [ "fastapi[standard]>=0.118.3", @@ -41,6 +41,9 @@ server = [ "click>=8.1.0", "uvloop>=0.22.1; python_version >= '3.14'" ] 
+all = [ + "mistral_common[opencv,sentencepiece,audio,image,guidance,hf-hub,server]" +] [project.scripts] mistral_common = "mistral_common.experimental.app.main:cli" @@ -90,7 +93,7 @@ warn_unused_ignores = true exclude = ["docs", "tools", "build"] [[tool.mypy.overrides]] -module = ["sentencepiece.*", "cv2", "cv2.*","soxr", "soundfile"] +module = ["sentencepiece.*", "cv2", "cv2.*","soxr", "soundfile", "llguidance", "llguidance.*", "jinja2", "jinja2.*"] ignore_missing_imports = true [tool.pytest.ini_options] diff --git a/src/mistral_common/guidance/__init__.py b/src/mistral_common/guidance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mistral_common/guidance/data/base_grammar.lark.jinja b/src/mistral_common/guidance/data/base_grammar.lark.jinja new file mode 100644 index 00000000..a1940f4a --- /dev/null +++ b/src/mistral_common/guidance/data/base_grammar.lark.jinja @@ -0,0 +1,19 @@ +{% if json_schema_str != None -%} +start: body | SAFE_WS? %json {{ json_schema_str }} +{% else -%} +start: body +{% endif -%} + +{% if mode == "auto" -%} +body: content | (content? 
fcalls) +{% elif mode == "any" -%} +body: fcalls +{% elif mode == "none" -%} +body: content +{% endif -%} + +fcalls: {{ fcall }} + +content: (/(.|\n)+/)+ + +SAFE_WS: /[ \t\r\n]+/ diff --git a/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja b/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja new file mode 100644 index 00000000..86b7230a --- /dev/null +++ b/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja @@ -0,0 +1,25 @@ +{% if json_schema_str != None -%} +start: body | %json {{ json_schema_str }} +{% else -%} +start: body +{% endif -%} + +{% if mode == "auto" -%} +body: think (content | fcalls) +{% elif mode == "any" -%} +body: think fcalls +{% elif mode == "none" -%} +body: think content +{% endif -%} + +fcalls: {{ fcall }} + +NO_THINK: /(.|\n)+/ & ~/(?s:.*)(<\/think>|)(?s:.*)/ +end_think[capture, lazy]: /(.|\n)*<\/think>/ + +text_first_optional: (NO_THINK)* +content: NO_THINK + +think: SAFE_WS? // text_first_optional end_think + +SAFE_WS: /[ \t\r\n]+/ diff --git a/src/mistral_common/guidance/data/think_grammar.lark.jinja b/src/mistral_common/guidance/data/think_grammar.lark.jinja new file mode 100644 index 00000000..59f16a71 --- /dev/null +++ b/src/mistral_common/guidance/data/think_grammar.lark.jinja @@ -0,0 +1,21 @@ +{% if json_schema_str != None -%} +start: body | %json {{ json_schema_str }} +{% else -%} +start: body +{% endif -%} + +{% if mode == "auto" -%} +body: think? (content | fcalls) +{% elif mode == "any" -%} +body: think? fcalls +{% elif mode == "none" -%} +body: think? content +{% endif -%} + +fcalls: content? 
{{ fcall }} + +content: (/(.|\n)+/)+ + +think: content + +SAFE_WS: /[ \t\r\n]+/ diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py new file mode 100644 index 00000000..63621fab --- /dev/null +++ b/src/mistral_common/guidance/grammar_factory.py @@ -0,0 +1,186 @@ +import json +from enum import Enum +from pathlib import Path +from typing import Any + +from mistral_common.guidance.tokenizer import from_mistral_tokenizer +from mistral_common.imports import ( + assert_jinja2_installed, + assert_llguidance_installed, + is_jinja2_installed, + is_llguidance_installed, +) +from mistral_common.protocol.instruct.tool_calls import Tool, ToolChoice +from mistral_common.tokens.tokenizers.base import TokenizerVersion +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.tekken import is_tekkenizer + +if is_llguidance_installed(): + import llguidance as llg + +if is_jinja2_installed(): + from jinja2 import Template + +JINJA_DIR = Path(__file__).parent / "data" + + +class _GrammarVariant(str, Enum): + base = "base" + plain_think = "plain_think" + think = "think" + + +JINJA_PATHS = { + _GrammarVariant.base: JINJA_DIR / "base_grammar.lark.jinja", + _GrammarVariant.plain_think: JINJA_DIR / "plain_text_think_grammar.lark.jinja", + _GrammarVariant.think: JINJA_DIR / "think_grammar.lark.jinja", +} + + +def _get_tool_args_json(tool: Tool) -> dict[str, Any]: + r"""Returns the JSON schema for a tool's arguments.""" + args = tool.function.parameters if tool.function.strict else {"type": "object"} + if args == {}: + args = {"type": "object", "properties": {}, "additionalProperties": False} + return args + + +def convert_tool_calls( + tools: list[Tool] | None, + mode: ToolChoice, + parallel_tool_calls: bool, +) -> str: + r"""Converts tool calls to a lark grammar string. + + Args: + tools: The list of tools available. + mode: The tool choice mode. 
+ parallel_tool_calls: Whether parallel tool calls are allowed. + + Returns: + The lark grammar string for tool calls. + """ + if mode == "none": + return "" + + any_strict_true = any(tool.function.strict for tool in tools) if tools else False + + if not tools or not any_strict_true: + single_tool_call = ' SAFE_WS? /.+/ SAFE_WS? %json {"type": "object"} SAFE_WS?' + return f"({single_tool_call})+" if parallel_tool_calls else single_tool_call + + individual_tool_calls = [] + for tool in tools: + args = _get_tool_args_json(tool) + individual_tool_calls.append( + f'( SAFE_WS? "{tool.function.name}" SAFE_WS? %json ' + f"{json.dumps(args, ensure_ascii=False)} SAFE_WS?)" + ) + + single_tool_call = f"{' | '.join(individual_tool_calls)}" + return f"({single_tool_call})+" if parallel_tool_calls else single_tool_call + + +class GrammarFactory: + r"""Generates grammars for a given tokenizer.""" + + @staticmethod + def is_supported(tokenizer: MistralTokenizer) -> bool: + r"""Checks whether the given tokenizer is supported by guidance. + + Guidance requires a Tekken tokenizer with version >= v11. + + Args: + tokenizer: The Mistral tokenizer to check. + + Returns: + Whether the tokenizer is supported. + """ + inner = tokenizer.instruct_tokenizer.tokenizer + return is_tekkenizer(inner) and not inner.version < TokenizerVersion.v11 + + def __init__(self, tokenizer: MistralTokenizer) -> None: + r"""Initialize the grammar factory. + + Args: + tokenizer: The Mistral tokenizer to generate grammars for. + + Raises: + ValueError: If the tokenizer is not supported (see + [`is_supported`][mistral_common.guidance.grammar_factory.GrammarFactory.is_supported]). 
+ """ + assert_llguidance_installed() + assert_jinja2_installed() + self._tokenizer = tokenizer.instruct_tokenizer.tokenizer + if not self.is_supported(tokenizer): + raise ValueError( + f"Guidance requires a Tekken tokenizer with version >= v11, " + f"got {type(self._tokenizer).__name__} {self._tokenizer.version.value}" + ) + self._llg_tokenizer = from_mistral_tokenizer(tokenizer) + + def select_jinja_template(self, reasoning: bool) -> str: + r"""Selects and returns the appropriate jinja template content based on tokenizer version and reasoning mode. + + Args: + reasoning: Whether reasoning/thinking mode is enabled. + + Returns: + The jinja template content as a string. + """ + tokenizer_version = self._tokenizer.version + if tokenizer_version < TokenizerVersion.v13: + jinja_key = _GrammarVariant.plain_think if reasoning else _GrammarVariant.base + else: + jinja_key = _GrammarVariant.think if reasoning else _GrammarVariant.base + jinja_path = JINJA_PATHS[jinja_key] + return jinja_path.read_text(encoding="utf-8") + + def get_lark( + self, + reasoning: bool, + mode: ToolChoice, + tools: list[Tool] | None, + json_schema: dict[str, Any] | None, + parallel_tool_calls: bool, + ) -> str: + r"""Renders a lark grammar from a jinja template. + + Args: + reasoning: Whether reasoning/thinking mode is enabled. + mode: The function calling mode (auto, any, none). + tools: The list of tools available. + json_schema: JSON schema to additionally allow, unioned with the grammar. + parallel_tool_calls: Whether parallel tool calls are allowed. + + Returns: + The rendered lark grammar string. 
+ """ + jinja_template = Template(self.select_jinja_template(reasoning=reasoning)) + lark_grammar = jinja_template.render( + mode=mode, + fcall=convert_tool_calls(tools, mode, parallel_tool_calls), + json_schema_str=json.dumps(json_schema, ensure_ascii=False) if json_schema else None, + ) + return lark_grammar + + def get_lark_for_json_schema(self, json_schema: dict[str, Any]) -> str: + r"""Returns a lark grammar that only accepts JSON objects matching the given schema.""" + return f"start: SAFE_WS? %json {json.dumps(json_schema, ensure_ascii=False)} \nSAFE_WS: /[ \t\r\n]+/" + + def get_matcher(self, lark: str) -> "llg.LLMatcher": + r"""Creates an LLMatcher from a lark grammar string. + + Args: + lark: The lark grammar string. + + Returns: + The LLMatcher instance. + + Raises: + ValueError: If the grammar is invalid. + """ + error = llg.LLMatcher.validate_grammar(lark) + if error: + raise ValueError(f"Invalid grammar: {error}") + return llg.LLMatcher(self._llg_tokenizer, lark) diff --git a/src/mistral_common/guidance/tokenizer.py b/src/mistral_common/guidance/tokenizer.py new file mode 100644 index 00000000..57cbdb93 --- /dev/null +++ b/src/mistral_common/guidance/tokenizer.py @@ -0,0 +1,104 @@ +import re +from typing import Any + +from mistral_common.imports import assert_llguidance_installed, is_llguidance_installed +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenizer +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.tekken import is_tekkenizer + +if is_llguidance_installed(): + import llguidance as llg + + +class MistralLLGTokenizer: + r"""Wraps a Tekken tokenizer for use with llguidance. + + Attributes: + eos_token_id: The end of string token id. + bos_token_id: The beginning of string token id. + tokens: The list of token byte representations. + special_token_ids: The list of special token ids. 
+ """ + + eos_token_id: int + bos_token_id: int + tokens: list[bytes] + special_token_ids: list[int] + + def __init__(self, tokenizer: Tokenizer) -> None: + r"""Initialize the wrapper. + + Args: + tokenizer: The Tekken tokenizer to wrap for llguidance compatibility. + + Raises: + TypeError: If the tokenizer is not a Tekkenizer. + ValueError: If a special token has an invalid format. + """ + assert_llguidance_installed() + if not is_tekkenizer(tokenizer): + raise TypeError(f"Guidance only supports Tekken tokenizers, got {type(tokenizer)}") + self._tokenizer = tokenizer + self.eos_token_id = self._tokenizer.eos_id + self.bos_token_id = self._tokenizer.bos_id + + self.tokens: list[bytes] = [] + self.special_token_ids: list[int] = [] + + seen_special_tokens: set[str] = set() + for i in range(self._tokenizer.n_words): + # Convert square brackets to angle brackets for special tokens, + # since llg only recognizes the latter. + if i < self._tokenizer.num_special_tokens: + token_rep = self._tokenizer.id_to_piece(i) + if match := re.fullmatch(r"\[(.*)\]", token_rep): + token_rep_llg = f"<{match.group(1)}>" + else: + token_rep_llg = token_rep + + if not re.fullmatch(r"<.*>", token_rep_llg): + raise ValueError(f"Invalid special token: {token_rep_llg} ({token_rep})") + if token_rep_llg in seen_special_tokens: + raise ValueError(f"Duplicate special token: {token_rep_llg} (already seen: {seen_special_tokens})") + seen_special_tokens.add(token_rep_llg) + self.special_token_ids.append(i) + self.tokens.append(token_rep_llg.encode("utf-8")) + else: + token_bytes = self._tokenizer.id_to_byte_piece(i, SpecialTokenPolicy.RAISE) + self.tokens.append(token_bytes) + + if len(self.special_token_ids) != self._tokenizer.num_special_tokens: + raise ValueError( + f"Expected {self._tokenizer.num_special_tokens} special tokens, but found {len(self.special_token_ids)}" + ) + + def __call__(self, s: str, *args: Any, **kwargs: Any) -> list[int]: + r"""Tokenizes a string into token ids. 
+ + Args: + s: The string to tokenize. + *args: Additional positional arguments (ignored). + **kwargs: Additional keyword arguments (ignored). + + Returns: + The list of token ids. + """ + return self._tokenizer.encode(s, bos=False, eos=False) + + +def from_mistral_tokenizer(tokenizer: MistralTokenizer) -> "llg.LLTokenizer": + r"""Creates an llguidance tokenizer from a Mistral tokenizer. + + Args: + tokenizer: The Mistral tokenizer to convert. Must wrap a Tekkenizer. + + Returns: + The llguidance tokenizer. + + Raises: + TypeError: If the underlying tokenizer is not a Tekkenizer. + """ + assert_llguidance_installed() + inner_tokenizer = tokenizer.instruct_tokenizer.tokenizer + tokenizer_data = MistralLLGTokenizer(inner_tokenizer) + return llg.LLTokenizer(llg.TokenizerWrapper(tokenizer_data)) diff --git a/src/mistral_common/imports.py b/src/mistral_common/imports.py index 98047e36..820ef541 100644 --- a/src/mistral_common/imports.py +++ b/src/mistral_common/imports.py @@ -24,6 +24,16 @@ def is_hf_hub_installed() -> bool: return is_package_installed("huggingface_hub") +@lru_cache() +def is_jinja2_installed() -> bool: + return is_package_installed("jinja2") + + +@lru_cache() +def is_llguidance_installed() -> bool: + return is_package_installed("llguidance") + + @lru_cache() def is_opencv_installed() -> bool: try: @@ -59,21 +69,36 @@ def is_soxr_installed() -> bool: return is_package_installed("soxr") +@lru_cache() def assert_hf_hub_installed() -> None: assert_package_installed("huggingface_hub", _get_dependency_error_message("huggingface_hub", "hf-hub")) +@lru_cache() +def assert_jinja2_installed() -> None: + assert_package_installed("jinja2", _get_dependency_error_message("jinja2", "guidance")) + + +@lru_cache() +def assert_llguidance_installed() -> None: + assert_package_installed("llguidance", _get_dependency_error_message("llguidance", "guidance")) + + +@lru_cache() def assert_opencv_installed() -> None: assert_package_installed("cv2", 
_get_dependency_error_message("opencv", "opencv")) +@lru_cache() def assert_sentencepiece_installed() -> None: assert_package_installed("sentencepiece", _get_dependency_error_message("sentencepiece", "sentencepiece")) +@lru_cache() def assert_soundfile_installed() -> None: assert_package_installed("soundfile", _get_dependency_error_message("soundfile", "soundfile")) +@lru_cache() def assert_soxr_installed() -> None: assert_package_installed("soxr", _get_dependency_error_message("soxr", "soxr")) diff --git a/src/mistral_common/tokens/tokenizers/sentencepiece.py b/src/mistral_common/tokens/tokenizers/sentencepiece.py index ada1d0c3..a7325af7 100644 --- a/src/mistral_common/tokens/tokenizers/sentencepiece.py +++ b/src/mistral_common/tokens/tokenizers/sentencepiece.py @@ -3,6 +3,7 @@ import warnings from functools import cached_property from pathlib import Path +from typing import TypeGuard import numpy as np @@ -259,3 +260,8 @@ def pad_id(self) -> int: def unk_id(self) -> int: r"""The unknown token id.""" return self._model.unk_id() # type: ignore + + +def is_sentencepiece_tokenizer(tokenizer: Tokenizer) -> TypeGuard[SentencePieceTokenizer]: + r"""Returns whether the tokenizer is a SentencePieceTokenizer.""" + return isinstance(tokenizer, SentencePieceTokenizer) diff --git a/src/mistral_common/tokens/tokenizers/tekken.py b/src/mistral_common/tokens/tokenizers/tekken.py index 27b52cf3..85f27926 100644 --- a/src/mistral_common/tokens/tokenizers/tekken.py +++ b/src/mistral_common/tokens/tokenizers/tekken.py @@ -5,7 +5,7 @@ from functools import cached_property from itertools import groupby from pathlib import Path -from typing import TypedDict +from typing import TypedDict, TypeGuard import numpy as np import tiktoken @@ -545,3 +545,8 @@ def _reload_mergeable_ranks( assert set(ranks.values()) == set(range(len(ranks))) return ranks + + +def is_tekkenizer(tokenizer: Tokenizer) -> TypeGuard[Tekkenizer]: + r"""Returns whether the tokenizer is a Tekkenizer.""" + return 
isinstance(tokenizer, Tekkenizer) diff --git a/tests/data/emoji.lark b/tests/data/emoji.lark new file mode 100644 index 00000000..6222e8f0 --- /dev/null +++ b/tests/data/emoji.lark @@ -0,0 +1,2 @@ +start: emoji+ +emoji: /(\p{Emoji_Presentation}|\p{Emoji}\uFE0F)/ diff --git a/tests/guidance/__init__.py b/tests/guidance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py new file mode 100644 index 00000000..1699c837 --- /dev/null +++ b/tests/guidance/test_guidance.py @@ -0,0 +1,1021 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal + +import pytest +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from mistral_common.guidance.grammar_factory import GrammarFactory +from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk +from mistral_common.protocol.instruct.messages import AssistantMessage +from mistral_common.protocol.instruct.normalize import get_normalizer +from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall, ToolChoice +from mistral_common.protocol.instruct.validator import ValidationMode, get_validator +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenizer, TokenizerVersion +from mistral_common.tokens.tokenizers.instruct import ( + InstructTokenizerBase, + InstructTokenizerV11, + InstructTokenizerV13, +) +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.tekken import Tekkenizer, is_tekkenizer +from tests.test_tekken import get_special_tokens, quick_vocab + +EMOJI_LARK_PATH = Path(__file__).parent.parent / "data" / "emoji.lark" + +Mode = Literal["auto", "any", "none"] +_AUTO_ANY: tuple[Mode, Mode] = ("auto", "any") + + +_NUM_SPECIAL_TOKENS = 100 +_EXTRA_TOKENS = [b"a", b"b", b"c", b"f", b"de", b"he", b"llo"] + + +def _build_tekken_mistral_tokenizer( + version: 
TokenizerVersion, + add_think: bool = False, +) -> MistralTokenizer: + r"""Builds a MistralTokenizer wrapping a programmatic Tekkenizer.""" + special_tokens = get_special_tokens(version, add_think=add_think) + + tekkenizer = Tekkenizer( + quick_vocab(_EXTRA_TOKENS), + special_tokens=special_tokens, + pattern=r".+", + vocab_size=256 + _NUM_SPECIAL_TOKENS, + num_special_tokens=_NUM_SPECIAL_TOKENS, + version=version, + ) + + match version: + case TokenizerVersion.v11: + instruct_tokenizer = InstructTokenizerV11(tekkenizer) + case TokenizerVersion.v13: + instruct_tokenizer = InstructTokenizerV13(tekkenizer) + case _: + raise ValueError(f"Unsupported version for programmatic Tekken build: {version}") + + normalizer = get_normalizer(version, tekkenizer.model_settings_builder) + validator = get_validator(version, mode=ValidationMode.test) + return MistralTokenizer(instruct_tokenizer, validator=validator, request_normalizer=normalizer) + + +@pytest.fixture(scope="module") +def v11_tekken() -> MistralTokenizer: + return _build_tekken_mistral_tokenizer(TokenizerVersion.v11) + + +@pytest.fixture(scope="module") +def v13_tekken() -> MistralTokenizer: + return _build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True) + + +_PAYMENT_PARAMS: dict[str, Any] = { + "type": "object", + "additionalProperties": False, + "properties": {"transaction_id": {"type": "string", "description": "The transaction id."}}, + "required": ["transaction_id"], +} + + +class ToolProvider: + @staticmethod + def retrieve_payment_status(strict: bool) -> Tool: + return Tool( + function=Function( + name="retrieve_payment_status", + description="Get payment status of a transaction", + strict=strict, + parameters=_PAYMENT_PARAMS, + ) + ) + + @staticmethod + def retrieve_payment_date(strict: bool) -> Tool: + return Tool( + function=Function( + name="retrieve_payment_date", + description="Get payment date of a transaction", + strict=strict, + parameters=_PAYMENT_PARAMS, + ) + ) + + +class 
SchemaProvider: + @staticmethod + def basic_person() -> dict[str, Any]: + class Person(BaseModel): + name: str + age: int + + class Config: + extra = "forbid" + + return Person.model_json_schema() + + @staticmethod + def basic_dict_of_list() -> dict[str, Any]: + class DoMerge(BaseModel): + new_clusters: dict[str, list[str]] = Field(default_factory=dict) + + class Config: + extra = "forbid" + + return DoMerge.model_json_schema() + + +class TestCase(BaseModel): + __test__ = False + model_config = ConfigDict(arbitrary_types_allowed=True) + tokenizer: Tokenizer + mode: Literal["auto", "any", "none"] + tokens: list[int] + should_fail_on: int | None + case_name: str + reasoning: bool = False + parallel_tool_calls: bool = True + tools: list[Tool] | None = None + json_schema: dict[str, Any] | None = None + # When set, uses this raw lark grammar instead of GrammarFactory.get_lark + raw_lark: str | None = None + + @model_validator(mode="after") + def validate_should_fail_on(self) -> TestCase: + if self.should_fail_on is not None and self.should_fail_on < 0: + self.should_fail_on += len(self.tokens) + if self.should_fail_on < 0: + raise ValueError( + f"should_fail_on={self.should_fail_on + len(self.tokens)} " + f"is out of bounds for tokens of length {len(self.tokens)}" + ) + return self + + @property + def name(self) -> str: + return f"tekken_{self.tokenizer.version.value}_{self.mode}_{self.case_name}" + + +def _encode_content( + instruct_tokenizer: InstructTokenizerBase, + content: str | list[Any], +) -> list[int]: + r"""Encodes assistant message content into token ids. + + Args: + instruct_tokenizer: The instruct tokenizer to use. + content: Either a plain string or a list of TextChunk / ThinkChunk / ToolCall. + + Returns: + The encoded token ids. 
+ """ + tokenizer = instruct_tokenizer.tokenizer + + if isinstance(content, str): + return instruct_tokenizer.encode_assistant_message( + AssistantMessage(content=content), is_before_last_user_message=False, continue_message=False + ) + + tool_calls = [x for x in content if isinstance(x, ToolCall)] + content_chunks = [x for x in content if not isinstance(x, ToolCall)] + + tokens: list[int] = [] + if content_chunks: + tokens = instruct_tokenizer.encode_assistant_message( + AssistantMessage(content=content_chunks), + is_before_last_user_message=False, + continue_message=False, + ) + # Strip trailing EOS before adding tool calls + while tokens and tokens[-1] == tokenizer.eos_id: + tokens.pop() + + for tc in tool_calls: + tokens += [ + tokenizer.get_special_token("[TOOL_CALLS]"), + *tokenizer.encode(tc.function.name, bos=False, eos=False), + tokenizer.get_special_token("[ARGS]"), + *tokenizer.encode(tc.function.arguments, bos=False, eos=False), + ] + tokens.append(tokenizer.eos_id) + return tokens + + +def _find_first_rejection( + factory: GrammarFactory, + tokens: list[int], + mode: Mode, + tools: list[Tool] | None, +) -> int: + r"""Finds the index of the first token rejected by the grammar. + + Args: + factory: The grammar factory. + tokens: The token sequence to test. + mode: The tool choice mode. + tools: The tools to pass to grammar generation. + + Returns: + The index of the first rejected token. + + Raises: + ValueError: If all tokens are accepted. 
+ """ + grammar = factory.get_lark( + reasoning=False, mode=ToolChoice(mode), tools=tools, json_schema=None, parallel_tool_calls=True + ) + matcher = factory.get_matcher(grammar) + for i, token in enumerate(tokens): + if not matcher.consume_token(token): + return i + raise ValueError("All tokens were accepted β€” expected a rejection") + + +def _token_debug_repr(tokenizer: Tekkenizer, token_id: int) -> str: + r"""Returns a debug representation of a token.""" + return repr(tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)) + + +def _generate_general_cases(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + # For programmatic tiny-vocab tokenizers, limit to ASCII-safe content + is_full_vocab = tokenizer.n_words > 1000 + if is_full_vocab: + items = { + "newline": "\n", + "blank": "_", + "text": "Hello!", + "text_with_newlines": "Hello!\n\nHow are you?\nI'm fine, thanks!", + "emojis": "πŸ˜ƒπŸ˜‚πŸ˜ŠπŸ˜πŸ˜˜πŸ˜—πŸ˜™πŸ˜šπŸ˜‹πŸ˜›πŸ˜œπŸ˜πŸ€‘πŸ€—πŸ€”πŸ€πŸ˜πŸ˜‘πŸ˜ΆπŸ˜¬", + "japanese": "こんにけは", + "arabic": "Ω…Ψ±Ψ­Ψ¨Ψ§ Ψ¨ΩƒΩ… في ΨΉΨ§Ω„Ω… Ψ§Ω„Ψ°ΩƒΨ§Ψ‘ Ψ§Ω„Ψ§Ψ΅Ψ·Ω†Ψ§ΨΉΩŠ", + } + else: + items = { + "text_a": "abc", + "text_b": "hello", + "text_de": "de", + } + for case_name, content in items.items(): + tokens = _encode_content(instruct_tokenizer, content) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + case_name=case_name, + should_fail_on=None, + mode="auto", + reasoning=False, + ) + ) + # v11: plain_text_think grammar requires ... first, so plain text fails + # v13+: think grammar has think? 
(optional), so plain text passes + reasoning_fail = 0 if tokenizer.version < TokenizerVersion.v13 else None + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + case_name=f"{case_name}_reasoning", + should_fail_on=reasoning_fail, + mode="auto", + reasoning=True, + ) + ) + return cases + + +def _generate_emoji_cases(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + # Emoji grammar needs a full vocab tokenizer (tiny programmatic tokenizers lack emoji tokens) + if tokenizer.n_words <= 1000: + return [] + + emoji_lark = EMOJI_LARK_PATH.read_text(encoding="utf-8") + cases: list[TestCase] = [] + # Encode raw emoji text (no assistant message wrapping) so we get pure emoji tokens + items: dict[str, tuple[str, int | None]] = { + "emojis_valid_a": ("πŸ˜ƒπŸ˜‚πŸ˜ŠπŸ˜πŸ˜˜πŸ˜—πŸ˜™πŸ˜šπŸ˜‹πŸ˜›πŸ˜œπŸ˜πŸ€‘πŸ€—πŸ€”πŸ€πŸ˜πŸ˜‘πŸ˜ΆπŸ˜¬", None), + "emojis_valid_b": ("πŸ˜ƒπŸ˜ƒπŸ˜ƒ", None), + "emojis_invalid_text": ("πŸ˜ƒsmile", len(tokenizer.encode("πŸ˜ƒ", bos=False, eos=False))), + "emojis_invalid_space": ("πŸ˜ƒ ", len(tokenizer.encode("πŸ˜ƒ", bos=False, eos=False))), + } + for case_name, (text, should_fail_on) in items.items(): + # Use raw encode to avoid assistant message wrapping (BOS/EOS) + tokens = tokenizer.encode(text, bos=False, eos=False) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + case_name=case_name, + should_fail_on=should_fail_on, + mode="auto", + raw_lark=emoji_lark, + ) + ) + return cases + + +def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + items: list[tuple[str, list[ToolCall], dict[Mode, int | None]]] = [ + ( + 
"single_fcall", + [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}'))], + {"auto": None, "any": None, "none": 0}, + ), + ( + "multi_fcall", + [ + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ToolCall(function=FunctionCall(name="hello_1", arguments='{"arg1": "val1", "arg2": "val2"}')), + ToolCall(function=FunctionCall(name="hello_2_3", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + {"auto": None, "any": None, "none": 0}, + ), + ( + "emoji_fcall", + [ + ToolCall( + function=FunctionCall(name="he🧦🧦o", arguments='{"arg1": "🐱", "arg2": "🐢", "arg🧦": "🧦"}'), + ) + ], + {"auto": None, "any": None, "none": 0}, + ), + ( + "pretty_printed_args", + [ + ToolCall( + function=FunctionCall( + name="hello", + arguments='{\n "arg1": "val1",\n "arg2": "val2"\n }\n', + ), + ) + ], + {"auto": None, "any": None, "none": 0}, + ), + ( + "japanese_fcall", + [ToolCall(function=FunctionCall(name="こんにけは", arguments='{"こん": "にけは"}'))], + {"auto": None, "any": None, "none": 0}, + ), + ] + for case_name, content, valid_for in items: + tokens = _encode_content(instruct_tokenizer, content) + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode=mode, + ) + ) + + # v11: plain_text_think mandates first, so bare tool calls fail at 0 + # v13+: think grammar has think? (optional), tool calls pass; fcalls allows content? 
prefix + if tokenizer.version < TokenizerVersion.v13: + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0} + else: + reasoning_valid_for = {"auto": None, "any": None, "none": 0} + for mode, should_fail_on in reasoning_valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=f"{case_name}_reasoning", + mode=mode, + reasoning=True, + ) + ) + + # Broken / missing args edge cases (token-level construction) + token_items: list[tuple[str, list[int], dict[Mode, int | None]]] = [ + ( + "fcall_broken_args", + [ + tokenizer.get_special_token("[TOOL_CALLS]"), + *tokenizer.encode("hello", bos=False, eos=False), + tokenizer.get_special_token("[ARGS]"), + *tokenizer.encode('{"a', bos=False, eos=False), + tokenizer.eos_id, + ], + {"auto": -1, "any": -1, "none": 0}, + ), + ( + "fcall_missing_args", + [ + tokenizer.get_special_token("[TOOL_CALLS]"), + *tokenizer.encode("hello", bos=False, eos=False), + tokenizer.get_special_token("[ARGS]"), + tokenizer.get_special_token("[TOOL_CALLS]"), + ], + {"auto": -1, "any": -1, "none": 0}, + ), + ] + + for case_name, tokens, valid_for in token_items: + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode=mode, + ) + ) + return cases + + +def _generate_cases_text_and_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + content = [ + TextChunk(text="Hello!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ] + tokens = _encode_content(instruct_tokenizer, content) + text_len = len(tokenizer.encode("Hello!", bos=False, eos=False)) + + # Non-reasoning uses base 
grammar where "any" mode is `body: fcalls` β€” no content allowed + valid_for: dict[Mode, int | None] = {"auto": None, "any": 0, "none": text_len} + + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name="text_fcall", + mode=mode, + ) + ) + + # v11: plain_text_think mandates first, so text+fcall without think fails at 0 + # v13+: think? optional, fcalls: content? fcall, so "any" mode accepts text before tool calls + if tokenizer.version < TokenizerVersion.v13: + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0} + else: + reasoning_valid_for = {"auto": None, "any": None, "none": text_len} + for mode, should_fail_on in reasoning_valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name="text_fcall_reasoning", + mode=mode, + reasoning=True, + ) + ) + + return cases + + +def _generate_cases_thinking_v11(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + r"""Generate thinking test cases for v11 (plain text think grammar).""" + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + thinks: list[tuple[str, list[Any], dict[Mode, int | None]]] = [ + ( + "force_think", + [TextChunk(text="Hello world!")], + {"auto": 0, "any": 0, "none": 0}, + ), + ( + "think_without_response", + [TextChunk(text="Hello!")], + {"auto": -1, "any": -1, "none": -1}, + ), + ( + "unclosed_think", + [TextChunk(text="Hello!")], + {"auto": -1, "any": -1, "none": -1}, + ), + ( + "plain_think_with_response", + [TextChunk(text="Hello!World!")], + { + "auto": None, + "any": len(tokenizer.encode("Hello!", bos=False, eos=False)), + "none": None, + }, + ), + ( + "think_with_tool_call", + [ + TextChunk(text="Hello!"), + 
ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + { + "auto": None, + "any": None, + "none": len(tokenizer.encode("Hello!", bos=False, eos=False)), + }, + ), + ( + "think_with_text_and_tool_call", + [ + TextChunk(text="Hello!Ho!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + { + "auto": len(tokenizer.encode("Hello!Ho!", bos=False, eos=False)), + "any": len(tokenizer.encode("Hello!", bos=False, eos=False)), + "none": len(tokenizer.encode("Hello!Ho!", bos=False, eos=False)), + }, + ), + ] + for case_name, content, valid_for in thinks: + tokens = _encode_content(instruct_tokenizer, content) + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode=mode, + reasoning=True, + ) + ) + return cases + + +def _generate_cases_thinking(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + r"""Generate thinking test cases for v13+ (structured think grammar with [THINK]/[/THINK] tokens).""" + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + + def _think_tokens(text: str) -> list[int]: + r"""Helper to encode a ThinkChunk.""" + assert isinstance(instruct_tokenizer, InstructTokenizerV13) + return instruct_tokenizer.encode_think(ThinkChunk(thinking=text)) + + thinks: list[tuple[str, list[Any], dict[Mode, int | None]]] = [ + ( + "plain_text", + [TextChunk(text="Hello world!")], + {"auto": None, "any": -1, "none": None}, + ), + ( + "plain_think", + [ThinkChunk(thinking="Hello!")], + {"auto": -1, "any": -1, "none": -1}, + ), + ( + "plain_think_with_response", + [ThinkChunk(thinking="Hello!"), TextChunk(text="World!")], + {"auto": None, "any": -1, "none": None}, + ), + ( + "think_with_tool_call", + [ + 
ThinkChunk(thinking="Hello!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + { + "auto": None, + "any": None, + "none": len(_think_tokens("Hello!")), + }, + ), + ( + "think_text_tool_call", + [ + ThinkChunk(thinking="Hello!"), + TextChunk(text="World!"), + ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), + ], + { + "auto": None, + "any": None, + "none": len(_think_tokens("Hello!")) + len(tokenizer.encode("World!", bos=False, eos=False)), + }, + ), + ] + for case_name, content, valid_for in thinks: + tokens = _encode_content(instruct_tokenizer, content) + for mode, should_fail_on in valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode=mode, + reasoning=True, + ) + ) + return cases + + +def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + single_call = [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}'))] + single_tokens = _encode_content(instruct_tokenizer, single_call) + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=single_tokens, + should_fail_on=None, + case_name="single_tool_call", + mode=mode, + parallel_tool_calls=False, + ) + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=single_tokens, + should_fail_on=0, + case_name="single_tool_call", + mode="none", + parallel_tool_calls=False, + ) + ) + + # v11: plain_text_think mandates first, so bare tool calls fail at 0 + # v13+: think? 
optional, tool calls pass + if tokenizer.version < TokenizerVersion.v13: + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0} + else: + reasoning_valid_for = {"auto": None, "any": None, "none": 0} + for mode, should_fail_on in reasoning_valid_for.items(): + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=single_tokens, + should_fail_on=should_fail_on, + case_name="single_tool_call_reasoning", + mode=mode, + parallel_tool_calls=False, + reasoning=True, + ) + ) + + # Multi tool call should fail when parallel_tool_calls=False + # Each tool call is a separate [TOOL_CALLS]...[ARGS]... sequence; + # the second [TOOL_CALLS] is where it fails. + multi_calls = [ + ToolCall(function=FunctionCall(name="fn1", arguments='{"arg1": "val1", "arg2": "val2"}')), + ToolCall(function=FunctionCall(name="fn2", arguments='{"arg1": "val1", "arg2": "val2"}')), + ] + multi_tokens = _encode_content(instruct_tokenizer, multi_calls) + single_tokens_with_eos = _encode_content(instruct_tokenizer, [multi_calls[0]]) + fail_idx = len(single_tokens_with_eos) - 1 + + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_tokens, + should_fail_on=fail_idx, + case_name="multi_tool_call_disallowed", + mode=mode, + parallel_tool_calls=False, + ) + ) + + return cases + + +def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + tools_strict = [ToolProvider.retrieve_payment_date(strict=True)] + + # 1. 
Non-strict tools β€” any function name/args accepted + non_strict_call = [ToolCall(function=FunctionCall(name="fn1", arguments='{"arg1": "val1", "arg2": "val2"}'))] + non_strict_tokens = _encode_content(instruct_tokenizer, non_strict_call) + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=non_strict_tokens, + should_fail_on=None, + case_name="single_non_strict_tool_call", + mode=mode, + tools=[ToolProvider.retrieve_payment_date(strict=False)], + ) + ) + + # 2. Correct strict tool call + strict_call = [ + ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"transaction_id": "12345"}')) + ] + strict_tokens = _encode_content(instruct_tokenizer, strict_call) + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=None, + case_name="single_strict_tool_call", + mode=mode, + tools=[ToolProvider.retrieve_payment_date(strict=True)], + ) + ) + + # 3. Wrong args for strict tool β€” it must fail somewhere before the end. + wrong_args_call = [ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"bogus": "12345"}'))] + wrong_args_tokens = _encode_content(instruct_tokenizer, wrong_args_call) + bogus_start = _find_first_rejection(factory, wrong_args_tokens, mode="auto", tools=tools_strict) + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_args_tokens, + should_fail_on=bogus_start, + case_name="strict_tool_call_wrong_args", + mode=mode, + tools=tools_strict, + ) + ) + + # 4. 
Wrong name for strict tool + wrong_name_call = [ToolCall(function=FunctionCall(name="fn1", arguments='{"transaction_id": "12345"}'))] + wrong_name_tokens = _encode_content(instruct_tokenizer, wrong_name_call) + fail_on_name = _find_first_rejection(factory, wrong_name_tokens, mode="auto", tools=tools_strict) + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_name_tokens, + should_fail_on=fail_on_name, + case_name="strict_tool_call_wrong_name", + mode=mode, + tools=tools_strict, + ) + ) + + # 5. Multiple strict tool calls (both correct) + multi_strict = [ + ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"transaction_id": "12345"}')), + ToolCall(function=FunctionCall(name="retrieve_payment_status", arguments='{"transaction_id": "12345"}')), + ] + multi_strict_tokens = _encode_content(instruct_tokenizer, multi_strict) + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_strict_tokens, + should_fail_on=None, + case_name="multiple_strict_tool_calls", + mode=mode, + tools=[ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ], + ) + ) + + # 6. reasoning=True variants + # v11: plain_text_think mandates first, so all bare tool calls fail at 0 + # v13+: think? 
optional, tool calls behave as without reasoning + if tokenizer.version < TokenizerVersion.v13: + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=0, + case_name="strict_tool_call_reasoning", + mode=mode, + tools=tools_strict, + reasoning=True, + ) + ) + else: + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=None, + case_name="strict_tool_call_reasoning", + mode=mode, + tools=tools_strict, + reasoning=True, + ) + ) + for mode in _AUTO_ANY: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_strict_tokens, + should_fail_on=None, + case_name="multiple_strict_tool_calls_reasoning", + mode=mode, + tools=[ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ], + reasoning=True, + ) + ) + + return cases + + +def _generate_json_schema(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + # JSON schema validation requires a full vocab tokenizer + if tokenizer.n_words <= 1000: + return [] + + cases: list[TestCase] = [] + items: list[tuple[str, str, int | None, dict[str, Any]]] = [ + ( + "basic_person_valid", + '{"name": "John", "age": 30}', + None, + SchemaProvider.basic_person(), + ), + ( + "invalid_json_missing_curly_bracket", + '"name": "John", "age": 30}', + 0, + {"type": "object"}, + ), + ( + "valid_json_white_spaces", + '\n {"name": "John", "age": 30}', + None, + {"type": "object"}, + ), + ( + "invalid_json_backslash_f", + '\f{"name": "John", "age": 30}', + 0, + {"type": "object"}, + ), + ( + "basic_person_invalid", + '{"age": "John", "name": 30}', + 1, + SchemaProvider.basic_person(), + ), + ( + "basic_person_non_strict_valid", + '{"age": "John", "name": 30}', + None, + {"type": "object"}, + ), + ( + 
"domerge_valid", + '{"new_clusters": {"b": ["a", "b", "c"], "d": ["e"]} }', + None, + SchemaProvider.basic_dict_of_list(), + ), + ] + + for case_name, text, should_fail_on, json_schema in items: + tokens = _encode_content(instruct_tokenizer, text) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode="auto", + json_schema=json_schema, + ) + ) + return cases + + +def _generate_cases(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + tokenizer_version = instruct_tokenizer.tokenizer.version + + cases = _generate_general_cases(mistral_tokenizer) + cases += _generate_emoji_cases(mistral_tokenizer) + cases += _generate_json_schema(mistral_tokenizer) + cases += _generate_cases_tool_calls(mistral_tokenizer) + cases += _generate_single_tool_call(mistral_tokenizer) + cases += _generate_strict_tool_calls(mistral_tokenizer, factory) + cases += _generate_cases_text_and_tool_calls(mistral_tokenizer) + + if not tokenizer_version < TokenizerVersion.v13: + cases += _generate_cases_thinking(mistral_tokenizer) + else: + cases += _generate_cases_thinking_v11(mistral_tokenizer) + + return cases + + +_grammar_factories: dict[int, GrammarFactory] = {} + + +def _get_grammar_factory(mistral_tokenizer: MistralTokenizer) -> GrammarFactory: + tok_id = id(mistral_tokenizer) + if tok_id not in _grammar_factories: + _grammar_factories[tok_id] = GrammarFactory(mistral_tokenizer) + return _grammar_factories[tok_id] + + +_ALL_TOKENIZERS: list[MistralTokenizer] = [ + _build_tekken_mistral_tokenizer(TokenizerVersion.v11), + _build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True), +] + +_ALL_CASES: list[TestCase] = [] +_ALL_MISTRAL_TOKENIZERS: dict[int, MistralTokenizer] = {} +for mistral_tokenizer in _ALL_TOKENIZERS: + tokenizer = 
mistral_tokenizer.instruct_tokenizer.tokenizer + _ALL_MISTRAL_TOKENIZERS[id(tokenizer)] = mistral_tokenizer + factory = _get_grammar_factory(mistral_tokenizer) + _ALL_CASES.extend(_generate_cases(mistral_tokenizer, factory)) + + +class TestGrammarFactory: + @pytest.mark.parametrize("test_case", _ALL_CASES, ids=lambda tc: tc.name) + def test_grammar(self, test_case: TestCase) -> None: + mistral_tokenizer = _ALL_MISTRAL_TOKENIZERS[id(test_case.tokenizer)] + factory = _get_grammar_factory(mistral_tokenizer) + + if test_case.raw_lark is not None: + grammar = test_case.raw_lark + elif test_case.json_schema is not None and test_case.tools is None: + grammar = factory.get_lark_for_json_schema(json_schema=test_case.json_schema) + else: + grammar = factory.get_lark( + reasoning=test_case.reasoning, + mode=ToolChoice(test_case.mode), + tools=test_case.tools, + json_schema=test_case.json_schema, + parallel_tool_calls=test_case.parallel_tool_calls, + ) + + matcher = factory.get_matcher(grammar) + + assert is_tekkenizer(test_case.tokenizer) + debug_tokens = [_token_debug_repr(test_case.tokenizer, t) for t in test_case.tokens] + for i, token in enumerate(test_case.tokens): + debug_bytes = _token_debug_repr(test_case.tokenizer, token) + if token != test_case.tokenizer.eos_id: + assert not matcher.is_stopped(), ( + f"Matcher stopped before token {i} id={token}\n\n" + f"Grammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" + ) + accepted = matcher.consume_token(token) + if i == test_case.should_fail_on: + if accepted: + raise AssertionError( + f"Token {token}={debug_bytes} at pos {i} was accepted but should have been rejected." + f"\n\nGrammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" + ) + break + elif not accepted: + raise AssertionError( + f"Token {token}={debug_bytes} at pos {i} was rejected but should have been accepted." 
+ f"\n\nGrammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" + ) + + @pytest.mark.parametrize( + ("tokenizer", "expected"), + [ + (MistralTokenizer.v1(), False), + (MistralTokenizer.v3(is_tekken=True), False), + (_build_tekken_mistral_tokenizer(TokenizerVersion.v11), True), + (_build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True), True), + ], + ) + def test_grammar_factory_is_supported(self, tokenizer: MistralTokenizer, expected: bool) -> None: + assert GrammarFactory.is_supported(tokenizer) is expected diff --git a/tests/guidance/test_tokenizer.py b/tests/guidance/test_tokenizer.py new file mode 100644 index 00000000..d302ddf3 --- /dev/null +++ b/tests/guidance/test_tokenizer.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import re +from unittest.mock import MagicMock + +import llguidance as llg +import pytest + +from mistral_common.guidance.tokenizer import MistralLLGTokenizer, from_mistral_tokenizer +from mistral_common.protocol.instruct.normalize import get_normalizer +from mistral_common.protocol.instruct.validator import ValidationMode, get_validator +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens, Tokenizer, TokenizerVersion +from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV7 +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.tekken import Tekkenizer +from tests.test_tekken import get_special_tokens, quick_vocab + +_NUM_SPECIAL_TOKENS = 100 +_EXTRA_TOKENS = [b"a", b"b", b"c", b"f", b"de", b"he", b"llo"] + + +@pytest.fixture(scope="module") +def tekkenizer() -> Tekkenizer: + special_tokens = get_special_tokens(TokenizerVersion.v7) + return Tekkenizer( + quick_vocab(_EXTRA_TOKENS), + special_tokens=special_tokens, + pattern=r".+", + vocab_size=256 + _NUM_SPECIAL_TOKENS, + num_special_tokens=_NUM_SPECIAL_TOKENS, + version=TokenizerVersion.v7, + ) + + +@pytest.fixture(scope="module") +def 
llg_tokenizer(tekkenizer: Tekkenizer) -> MistralLLGTokenizer: + return MistralLLGTokenizer(tekkenizer) + + +@pytest.fixture(scope="module") +def mistral_tokenizer(tekkenizer: Tekkenizer) -> MistralTokenizer: + instruct_tokenizer = InstructTokenizerV7(tekkenizer) + normalizer = get_normalizer(tekkenizer.version, tekkenizer.model_settings_builder) + validator = get_validator(tekkenizer.version, mode=ValidationMode.test) + return MistralTokenizer(instruct_tokenizer, validator=validator, request_normalizer=normalizer) + + +@pytest.fixture(scope="module") +def ll_tokenizer(mistral_tokenizer: MistralTokenizer) -> llg.LLTokenizer: + return from_mistral_tokenizer(mistral_tokenizer) + + +class TestMistralLLGTokenizer: + def test_init_rejects_non_tekkenizer(self) -> None: + mock_tokenizer = MagicMock(spec=Tokenizer) + with pytest.raises(TypeError, match="Guidance only supports Tekken tokenizers"): + MistralLLGTokenizer(mock_tokenizer) + + def test_eos_and_bos_ids(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None: + assert llg_tokenizer.eos_token_id == tekkenizer.eos_id + assert llg_tokenizer.bos_token_id == tekkenizer.bos_id + + def test_tokens_length(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None: + assert len(llg_tokenizer.tokens) == tekkenizer.n_words + + def test_special_token_ids_count(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None: + assert len(llg_tokenizer.special_token_ids) == tekkenizer.num_special_tokens + + def test_all_special_tokens_are_angle_bracketed(self, llg_tokenizer: MistralLLGTokenizer) -> None: + for i in llg_tokenizer.special_token_ids: + token_bytes = llg_tokenizer.tokens[i] + token_str = token_bytes.decode("utf-8") + assert re.fullmatch(r"<.*>", token_str), f"Special token at id={i} is not angle-bracketed: {token_str!r}" + + def test_special_token_conversion(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None: + checked = 0 + for token in 
SpecialTokens:
+            try:
+                rank = tekkenizer.get_special_token(token.value)
+            except ValueError:
+                continue
+
+            llg_bytes = llg_tokenizer.tokens[rank]
+            if re.fullmatch(r"\[.*\]", token.value):
+                expected = token.value.replace("[", "<").replace("]", ">")
+            else:
+                expected = token.value
+
+            assert llg_bytes == expected.encode("utf-8"), (
+                f"Token {token.name} at rank={rank}: expected {expected!r}, got {llg_bytes!r}"
+            )
+            checked += 1
+
+        # Filler tokens (<SPECIAL_*>) should be preserved as-is
+        for i in range(tekkenizer.num_special_tokens):
+            piece = tekkenizer.id_to_piece(i)
+            if re.fullmatch(r"<SPECIAL_\d+>", piece):
+                assert llg_tokenizer.tokens[i] == piece.encode("utf-8"), (
+                    f"Filler token at id={i}: expected {piece!r}, got {llg_tokenizer.tokens[i]!r}"
+                )
+                checked += 1
+
+        assert checked == tekkenizer.num_special_tokens
+
+    def test_non_special_tokens_match_byte_pieces(
+        self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer
+    ) -> None:
+        for i in range(tekkenizer.num_special_tokens, tekkenizer.n_words):
+            expected = tekkenizer.id_to_byte_piece(i, SpecialTokenPolicy.RAISE)
+            assert llg_tokenizer.tokens[i] == expected, (
+                f"Token at id={i}: expected {expected!r}, got {llg_tokenizer.tokens[i]!r}"
+            )
+
+    def test_call_encodes_string(self, tekkenizer: Tekkenizer, llg_tokenizer: MistralLLGTokenizer) -> None:
+        test_string = "abc"
+        assert llg_tokenizer(test_string) == tekkenizer.encode(test_string, bos=False, eos=False)
+
+
+class TestFromMistralTokenizer:
+    def test_properties(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None:
+        assert isinstance(ll_tokenizer, llg.LLTokenizer)
+        assert ll_tokenizer.vocab_size == tekkenizer.n_words
+        assert ll_tokenizer.eos_token == tekkenizer.eos_id
+        for i in range(tekkenizer.num_special_tokens):
+            assert ll_tokenizer.is_special_token(i), f"Token id={i} should be special"
+
+    def test_tokenize_str_matches_tekkenizer(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None:
+        test_strings = ["abc", "hello", 
"de", "abcdefhello"] + for s in test_strings: + expected = tekkenizer.encode(s, bos=False, eos=False) + result = ll_tokenizer.tokenize_str(s) + assert result == expected, f"Mismatch for {s!r}: expected {expected}, got {result}" + + def test_decode_str_roundtrip(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None: + test_strings = ["abc", "hello", "de"] + for s in test_strings: + tokens = tekkenizer.encode(s, bos=False, eos=False) + decoded = ll_tokenizer.decode_str(tokens) + assert decoded == s, f"Roundtrip failed for {s!r}: got {decoded!r}" + + def test_decode_bytes_roundtrip(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None: + test_strings = ["abc", "hello"] + for s in test_strings: + tokens = tekkenizer.encode(s, bos=False, eos=False) + decoded = ll_tokenizer.decode_bytes(tokens) + assert decoded == s.encode("utf-8"), f"Roundtrip failed for {s!r}: got {decoded!r}" + + def test_decode_special_tokens(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None: + for token in SpecialTokens: + try: + rank = tekkenizer.get_special_token(token.value) + except ValueError: + continue + + if re.fullmatch(r"\[.*\]", token.value): + expected = token.value.replace("[", "<").replace("]", ">") + else: + expected = token.value + + decoded_str = ll_tokenizer.decode_str([rank]) + assert decoded_str == expected, ( + f"decode_str for {token.name} (id={rank}): expected {expected!r}, got {decoded_str!r}" + ) + decoded_bytes = ll_tokenizer.decode_bytes([rank]) + assert decoded_bytes == expected.encode("utf-8"), ( + f"decode_bytes for {token.name} (id={rank}): " + f"expected {expected.encode('utf-8')!r}, got {decoded_bytes!r}" + ) diff --git a/tests/test_imports.py b/tests/test_imports.py index 4ba25dad..df1de977 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,19 +1,23 @@ import builtins from functools import _lru_cache_wrapper from types import ModuleType -from typing import Any, Callable +from typing import Any from 
unittest.mock import MagicMock, Mock, patch import pytest from mistral_common.imports import ( assert_hf_hub_installed, + assert_jinja2_installed, + assert_llguidance_installed, assert_opencv_installed, assert_package_installed, assert_sentencepiece_installed, assert_soundfile_installed, assert_soxr_installed, is_hf_hub_installed, + is_jinja2_installed, + is_llguidance_installed, is_opencv_installed, is_package_installed, is_sentencepiece_installed, @@ -23,6 +27,8 @@ _IS_INSTALLED_TO_TESTS = [ is_hf_hub_installed, + is_jinja2_installed, + is_llguidance_installed, is_sentencepiece_installed, is_soundfile_installed, is_soxr_installed, @@ -34,6 +40,16 @@ assert_hf_hub_installed, "`huggingface_hub` is not installed. Please install it with `pip install mistral-common[hf-hub]`", ), + ( + is_jinja2_installed, + assert_jinja2_installed, + "`jinja2` is not installed. Please install it with `pip install mistral-common[guidance]`", + ), + ( + is_llguidance_installed, + assert_llguidance_installed, + "`llguidance` is not installed. 
Please install it with `pip install mistral-common[guidance]`", + ), ( is_opencv_installed, assert_opencv_installed, @@ -117,16 +133,19 @@ def test_is_installed(mock_is_package_installed: MagicMock, is_installed_fn: _lr def test_assert_installed( mock_is_package_installed: MagicMock, is_installed_fn: _lru_cache_wrapper, - assert_fn: Callable[[], None], + assert_fn: _lru_cache_wrapper, error_message: str, ) -> None: is_installed_fn.cache_clear() + assert_fn.cache_clear() mock_is_package_installed.return_value = True assert_fn() is_installed_fn.cache_clear() + assert_fn.cache_clear() mock_is_package_installed.return_value = False with pytest.raises(ImportError) as exc_info: assert_fn() assert str(exc_info.value) == error_message is_installed_fn.cache_clear() + assert_fn.cache_clear() diff --git a/uv.lock b/uv.lock index ef30b985..8b84d3f5 100644 --- a/uv.lock +++ b/uv.lock @@ -759,6 +759,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/0e/b27cdbaccf30b890c40ed1da9fd4a3593a5cf94dae54fb34f8a4b74fcd3f/jsonschema_specifications-2025.4.1-py3-none-any.whl", hash = "sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af", size = 18437, upload-time = "2025-04-23T12:34:05.422Z" }, ] +[[package]] +name = "llguidance" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/48/3f7a9d3ff1b36bba92b5107a3a21286821227afe9ea464736133994d61fb/llguidance-1.3.0.tar.gz", hash = "sha256:861249afd51dc325646834462ea827e57a5c2b2042e108e6aae7059fdad9104d", size = 1070460, upload-time = "2025-10-20T19:58:44.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/33/be5acb85cd8cdc4afde33d9c234eece9f318e087920255af3c05864cd3e7/llguidance-1.3.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f7685222660a762e481ac633d49cc559c64980fe2ee59c8f932a5bb5cbc0c2c2", size = 3220647, upload-time = "2025-10-20T19:58:42.542Z" }, + { url = 
"https://files.pythonhosted.org/packages/82/e6/b48bda5b15efeaeb62bd0dba8fc6a01d4ae5457a85dbb5d18632385fe15c/llguidance-1.3.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:098030ff0687261a3f1bd54cf21fe951fc861d56d37a0671250dd36677eaf224", size = 3099830, upload-time = "2025-10-20T19:58:40.826Z" }, + { url = "https://files.pythonhosted.org/packages/aa/11/44389d3d1526d7a5c38ffd587a5ebc61d7bee443ac1dea95f2089ad58f5f/llguidance-1.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f6caca5d78db7f76e1fbb0fff8607b861c32d47fa3d5dee2fc49de27ee269df", size = 2835242, upload-time = "2025-10-20T19:58:34.518Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ca/53ea256396405e4dee70d5a4a35e18543408e18bb16b251d6ca6b5d80310/llguidance-1.3.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0612bb3f034d2487b6e8f9561f02a94a6039d88273bf0c5c539a3bd3895e47d2", size = 3297480, upload-time = "2025-10-20T19:58:37.033Z" }, + { url = "https://files.pythonhosted.org/packages/83/a8/1ff2bedb8f9acb46a2d2d603415d272bb622c142ea86f5b95445cc6e366c/llguidance-1.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc17e9dd602c3879bf91664a64bf72f54c74dbfbeb24ccfab6a5fe435b12f7aa", size = 3033133, upload-time = "2025-10-20T19:58:38.721Z" }, + { url = "https://files.pythonhosted.org/packages/d7/a7/9b8086c0cfdddf3f6d47b173a404fa7ac46272f7affbee082c36740f4f1c/llguidance-1.3.0-cp39-abi3-win32.whl", hash = "sha256:2f6f558485a43e273fc5c6c974a9a3ace5d5e170076db9b40e0560e41c3ff18f", size = 2598109, upload-time = "2025-10-20T19:58:47.656Z" }, + { url = "https://files.pythonhosted.org/packages/5a/7e/809349638231f469b9056c0e1bfd924d5ef5558b3b3ec72d093b6fad33b1/llguidance-1.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:1d1cd1c8618d1a13605d3e057c978651e551c8c469b481ee4041f1d6c436002d", size = 2789946, upload-time = "2025-10-20T19:58:45.958Z" }, +] + [[package]] name = "markdown" version = "3.8.2" @@ -874,10 +889,27 @@ dependencies = 
[ ] [package.optional-dependencies] +all = [ + { name = "click" }, + { name = "fastapi", extra = ["standard"] }, + { name = "huggingface-hub" }, + { name = "jinja2" }, + { name = "llguidance" }, + { name = "opencv-python-headless" }, + { name = "pydantic-settings" }, + { name = "sentencepiece" }, + { name = "soundfile" }, + { name = "soxr" }, + { name = "uvloop", version = "0.22.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, +] audio = [ { name = "soundfile" }, { name = "soxr" }, ] +guidance = [ + { name = "jinja2" }, + { name = "llguidance" }, +] hf-hub = [ { name = "huggingface-hub" }, ] @@ -934,8 +966,11 @@ requires-dist = [ { name = "click", marker = "extra == 'server'", specifier = ">=8.1.0" }, { name = "fastapi", extras = ["standard"], marker = "extra == 'server'", specifier = ">=0.118.3" }, { name = "huggingface-hub", marker = "extra == 'hf-hub'", specifier = ">=1.0" }, + { name = "jinja2", marker = "extra == 'guidance'", specifier = ">=3.1.0" }, { name = "jsonschema", specifier = ">=4.21.1" }, + { name = "llguidance", marker = "extra == 'guidance'", specifier = ">=1.3.0,<1.4.0" }, { name = "mistral-common", extras = ["opencv"], marker = "extra == 'image'" }, + { name = "mistral-common", extras = ["opencv", "sentencepiece", "audio", "image", "guidance", "hf-hub", "server"], marker = "extra == 'all'" }, { name = "mistral-common", extras = ["soundfile"], marker = "extra == 'audio'" }, { name = "mistral-common", extras = ["soxr"], marker = "extra == 'audio'" }, { name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.25,<2.4" }, @@ -955,7 +990,7 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.11.0" }, { name = "uvloop", marker = "python_full_version >= '3.14' and extra == 'server'", specifier = ">=0.22.1" }, ] -provides-extras = ["opencv", "sentencepiece", "soundfile", "soxr", "audio", "image", "hf-hub", "server"] +provides-extras = ["opencv", "sentencepiece", 
"soundfile", "soxr", "audio", "image", "guidance", "hf-hub", "server", "all"] [package.metadata.requires-dev] dev = [ From 610ea8dd573474de1575a268df97cd4b08ac38cc Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:02:46 +0100 Subject: [PATCH 02/21] Refactor lark --- src/mistral_common/guidance/grammar_factory.py | 8 ++++---- tests/guidance/test_guidance.py | 10 ++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 63621fab..4948c3b7 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -136,9 +136,9 @@ def select_jinja_template(self, reasoning: bool) -> str: jinja_path = JINJA_PATHS[jinja_key] return jinja_path.read_text(encoding="utf-8") - def get_lark( + def get_lark_from_jinja( self, - reasoning: bool, + template: str, mode: ToolChoice, tools: list[Tool] | None, json_schema: dict[str, Any] | None, @@ -147,7 +147,7 @@ def get_lark( r"""Renders a lark grammar from a jinja template. Args: - reasoning: Whether reasoning/thinking mode is enabled. + template: Jinja template to render as a string. mode: The function calling mode (auto, any, none). tools: The list of tools available. json_schema: JSON schema to additionally allow, unioned with the grammar. @@ -156,7 +156,7 @@ def get_lark( Returns: The rendered lark grammar string. 
""" - jinja_template = Template(self.select_jinja_template(reasoning=reasoning)) + jinja_template = Template(template) lark_grammar = jinja_template.render( mode=mode, fcall=convert_tool_calls(tools, mode, parallel_tool_calls), diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index 1699c837..30c07022 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -222,8 +222,9 @@ def _find_first_rejection( Raises: ValueError: If all tokens are accepted. """ - grammar = factory.get_lark( - reasoning=False, mode=ToolChoice(mode), tools=tools, json_schema=None, parallel_tool_calls=True + template = factory.select_jinja_template(reasoning=False) + grammar = factory.get_lark_from_jinja( + template=template, mode=ToolChoice(mode), tools=tools, json_schema=None, parallel_tool_calls=True ) matcher = factory.get_matcher(grammar) for i, token in enumerate(tokens): @@ -975,8 +976,9 @@ def test_grammar(self, test_case: TestCase) -> None: elif test_case.json_schema is not None and test_case.tools is None: grammar = factory.get_lark_for_json_schema(json_schema=test_case.json_schema) else: - grammar = factory.get_lark( - reasoning=test_case.reasoning, + template = factory.select_jinja_template(reasoning=test_case.reasoning) + grammar = factory.get_lark_from_jinja( + template=template, mode=ToolChoice(test_case.mode), tools=test_case.tools, json_schema=test_case.json_schema, From b173216839962ac959bed61738e097442e084b1a Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:19:04 +0100 Subject: [PATCH 03/21] Cache part of grammar --- .../guidance/grammar_factory.py | 49 ++++++++++++++----- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 4948c3b7..82f6ced4 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ 
b/src/mistral_common/guidance/grammar_factory.py @@ -1,5 +1,6 @@ import json from enum import Enum +from functools import lru_cache from pathlib import Path from typing import Any @@ -24,6 +25,34 @@ JINJA_DIR = Path(__file__).parent / "data" +@lru_cache() +def _cached_get_jinja_template(tokenizer_version: TokenizerVersion, reasoning: bool) -> str: + if tokenizer_version < TokenizerVersion.v13: + jinja_key = _GrammarVariant.plain_think if reasoning else _GrammarVariant.base + else: + jinja_key = _GrammarVariant.think if reasoning else _GrammarVariant.base + jinja_path = JINJA_PATHS[jinja_key] + return jinja_path.read_text(encoding="utf-8") + + +@lru_cache() +def _cached_get_lark_from_jinja( + template: str, + mode: ToolChoice, + fcall: str, + json_schema_str: str | None, + parallel_tool_calls: bool, +) -> str: + jinja_template = Template(template) + lark_grammar = jinja_template.render( + mode=mode, + fcall=fcall, + json_schema_str=json_schema_str, + parallel_tool_calls=parallel_tool_calls, + ) + return lark_grammar + + class _GrammarVariant(str, Enum): base = "base" plain_think = "plain_think" @@ -128,13 +157,7 @@ def select_jinja_template(self, reasoning: bool) -> str: Returns: The jinja template content as a string. """ - tokenizer_version = self._tokenizer.version - if tokenizer_version < TokenizerVersion.v13: - jinja_key = _GrammarVariant.plain_think if reasoning else _GrammarVariant.base - else: - jinja_key = _GrammarVariant.think if reasoning else _GrammarVariant.base - jinja_path = JINJA_PATHS[jinja_key] - return jinja_path.read_text(encoding="utf-8") + return _cached_get_jinja_template(tokenizer_version=self._tokenizer.version, reasoning=reasoning) def get_lark_from_jinja( self, @@ -156,13 +179,15 @@ def get_lark_from_jinja( Returns: The rendered lark grammar string. 
""" - jinja_template = Template(template) - lark_grammar = jinja_template.render( + fcall = convert_tool_calls(tools, mode, parallel_tool_calls) + json_schema_str = json.dumps(json_schema, ensure_ascii=False) if json_schema else None + return _cached_get_lark_from_jinja( + template=template, mode=mode, - fcall=convert_tool_calls(tools, mode, parallel_tool_calls), - json_schema_str=json.dumps(json_schema, ensure_ascii=False) if json_schema else None, + tools=fcall, + json_schema_str=json_schema_str, + parallel_tool_calls=parallel_tool_calls, ) - return lark_grammar def get_lark_for_json_schema(self, json_schema: dict[str, Any]) -> str: r"""Returns a lark grammar that only accepts JSON objects matching the given schema.""" From 84598af8d046cd8de5a97c39d7b216ccd7e569f7 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:32:04 +0100 Subject: [PATCH 04/21] Add required and NamedToolChoice support --- .../guidance/data/base_grammar.lark.jinja | 2 + .../data/plain_text_think_grammar.lark.jinja | 2 +- .../guidance/data/think_grammar.lark.jinja | 2 +- .../guidance/grammar_factory.py | 41 +- tests/guidance/test_guidance.py | 436 +++++++++++++++++- 5 files changed, 443 insertions(+), 40 deletions(-) diff --git a/src/mistral_common/guidance/data/base_grammar.lark.jinja b/src/mistral_common/guidance/data/base_grammar.lark.jinja index a1940f4a..b43b1998 100644 --- a/src/mistral_common/guidance/data/base_grammar.lark.jinja +++ b/src/mistral_common/guidance/data/base_grammar.lark.jinja @@ -8,6 +8,8 @@ start: body body: content | (content? fcalls) {% elif mode == "any" -%} body: fcalls +{% elif mode == "required" -%} +body: content? 
fcalls {% elif mode == "none" -%} body: content {% endif -%} diff --git a/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja b/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja index 86b7230a..489d7271 100644 --- a/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja +++ b/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja @@ -6,7 +6,7 @@ start: body {% if mode == "auto" -%} body: think (content | fcalls) -{% elif mode == "any" -%} +{% elif mode == "any" or mode == "required" -%} body: think fcalls {% elif mode == "none" -%} body: think content diff --git a/src/mistral_common/guidance/data/think_grammar.lark.jinja b/src/mistral_common/guidance/data/think_grammar.lark.jinja index 59f16a71..6acc20e6 100644 --- a/src/mistral_common/guidance/data/think_grammar.lark.jinja +++ b/src/mistral_common/guidance/data/think_grammar.lark.jinja @@ -6,7 +6,7 @@ start: body {% if mode == "auto" -%} body: think? (content | fcalls) -{% elif mode == "any" -%} +{% elif mode == "any" or mode == "required" -%} body: think? fcalls {% elif mode == "none" -%} body: think? 
content diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 82f6ced4..dc4e884b 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -11,7 +11,7 @@ is_jinja2_installed, is_llguidance_installed, ) -from mistral_common.protocol.instruct.tool_calls import Tool, ToolChoice +from mistral_common.protocol.instruct.tool_calls import NamedToolChoice, Tool, ToolChoice, ToolChoiceEnum from mistral_common.tokens.tokenizers.base import TokenizerVersion from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.tekken import is_tekkenizer @@ -38,7 +38,7 @@ def _cached_get_jinja_template(tokenizer_version: TokenizerVersion, reasoning: b @lru_cache() def _cached_get_lark_from_jinja( template: str, - mode: ToolChoice, + mode: str, fcall: str, json_schema_str: str | None, parallel_tool_calls: bool, @@ -95,19 +95,28 @@ def convert_tool_calls( any_strict_true = any(tool.function.strict for tool in tools) if tools else False if not tools or not any_strict_true: - single_tool_call = ' SAFE_WS? /.+/ SAFE_WS? %json {"type": "object"} SAFE_WS?' - return f"({single_tool_call})+" if parallel_tool_calls else single_tool_call - - individual_tool_calls = [] - for tool in tools: - args = _get_tool_args_json(tool) - individual_tool_calls.append( - f'( SAFE_WS? "{tool.function.name}" SAFE_WS? %json ' - f"{json.dumps(args, ensure_ascii=False)} SAFE_WS?)" + if not isinstance(mode, NamedToolChoice): + grammar_tool_call = ' SAFE_WS? /.+/ SAFE_WS? %json {"type": "object"} SAFE_WS?' + else: + grammar_tool_call = ( + f' SAFE_WS? "{mode.function.name}" SAFE_WS? %json {{"type": "object"}} SAFE_WS?' 
+ ) + else: + grammar_per_tool = [] + tools = ( + [next(tool for tool in tools if tool.function.name == mode.function.name)] + if isinstance(mode, NamedToolChoice) + else tools ) + for tool in tools: + args = _get_tool_args_json(tool) + grammar_per_tool.append( + f'( SAFE_WS? "{tool.function.name}" SAFE_WS? %json ' + f"{json.dumps(args, ensure_ascii=False)} SAFE_WS?)" + ) + grammar_tool_call = f"{' | '.join(grammar_per_tool)}" - single_tool_call = f"{' | '.join(individual_tool_calls)}" - return f"({single_tool_call})+" if parallel_tool_calls else single_tool_call + return f"({grammar_tool_call})+" if parallel_tool_calls else grammar_tool_call class GrammarFactory: @@ -181,10 +190,12 @@ def get_lark_from_jinja( """ fcall = convert_tool_calls(tools, mode, parallel_tool_calls) json_schema_str = json.dumps(json_schema, ensure_ascii=False) if json_schema else None + # NamedToolChoice forces a specific tool, which maps to "required" grammar + template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else mode return _cached_get_lark_from_jinja( template=template, - mode=mode, - tools=fcall, + mode=template_mode.value, + fcall=fcall, json_schema_str=json_schema_str, parallel_tool_calls=parallel_tool_calls, ) diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index 30c07022..7df13fd5 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -10,7 +10,17 @@ from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.normalize import get_normalizer -from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall, ToolChoice +from mistral_common.protocol.instruct.tool_calls import ( + Function, + FunctionCall, + FunctionName, + NamedToolChoice, + Tool, + ToolCall, + ToolChoice, + ToolChoiceEnum, + ToolTypes, +) from 
mistral_common.protocol.instruct.validator import ValidationMode, get_validator from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenizer, TokenizerVersion from mistral_common.tokens.tokenizers.instruct import ( @@ -24,8 +34,9 @@ EMOJI_LARK_PATH = Path(__file__).parent.parent / "data" / "emoji.lark" -Mode = Literal["auto", "any", "none"] +Mode = Literal["auto", "any", "none", "required"] _AUTO_ANY: tuple[Mode, Mode] = ("auto", "any") +_REQUIRED_MODE: Mode = "required" _NUM_SPECIAL_TOKENS = 100 @@ -130,7 +141,7 @@ class TestCase(BaseModel): __test__ = False model_config = ConfigDict(arbitrary_types_allowed=True) tokenizer: Tokenizer - mode: Literal["auto", "any", "none"] + mode: Literal["auto", "any", "none", "required"] | NamedToolChoice tokens: list[int] should_fail_on: int | None case_name: str @@ -205,7 +216,7 @@ def _encode_content( def _find_first_rejection( factory: GrammarFactory, tokens: list[int], - mode: Mode, + mode: Literal["auto", "any", "none", "required"] | NamedToolChoice, tools: list[Tool] | None, ) -> int: r"""Finds the index of the first token rejected by the grammar. @@ -213,7 +224,7 @@ def _find_first_rejection( Args: factory: The grammar factory. tokens: The token sequence to test. - mode: The tool choice mode. + mode: The tool choice mode (literal or NamedToolChoice). tools: The tools to pass to grammar generation. Returns: @@ -223,9 +234,14 @@ def _find_first_rejection( ValueError: If all tokens are accepted. 
""" template = factory.select_jinja_template(reasoning=False) - grammar = factory.get_lark_from_jinja( - template=template, mode=ToolChoice(mode), tools=tools, json_schema=None, parallel_tool_calls=True - ) + if isinstance(mode, NamedToolChoice): + grammar = factory.get_lark_from_jinja( + template=template, mode=mode, tools=tools, json_schema=None, parallel_tool_calls=True + ) + else: + grammar = factory.get_lark_from_jinja( + template=template, mode=ToolChoiceEnum(mode), tools=tools, json_schema=None, parallel_tool_calls=True + ) matcher = factory.get_matcher(grammar) for i, token in enumerate(tokens): if not matcher.consume_token(token): @@ -384,6 +400,18 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test mode=mode, ) ) + # The "required" mode should behave like "any" - tool calls are required + # and text content is optional (content? fcalls) + if case_name != "single_non_strict_tool_call": # handled in strict tests + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=valid_for.get("any"), + case_name=case_name, + mode="required", + ) + ) # v11: plain_text_think mandates first, so bare tool calls fail at 0 # v13+: think grammar has think? (optional), tool calls pass; fcalls allows content? 
prefix @@ -402,6 +430,21 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test reasoning=True, ) ) + # "required" mode with reasoning + if tokenizer.version < TokenizerVersion.v13: + reasoning_required_fail = 0 + else: + reasoning_required_fail = None + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=reasoning_required_fail, + case_name=f"{case_name}_reasoning", + mode="required", + reasoning=True, + ) + ) # Broken / missing args edge cases (token-level construction) token_items: list[tuple[str, list[int], dict[Mode, int | None]]] = [ @@ -414,7 +457,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test *tokenizer.encode('{"a', bos=False, eos=False), tokenizer.eos_id, ], - {"auto": -1, "any": -1, "none": 0}, + {"auto": -1, "any": -1, "none": 0, "required": -1}, ), ( "fcall_missing_args", @@ -424,7 +467,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test tokenizer.get_special_token("[ARGS]"), tokenizer.get_special_token("[TOOL_CALLS]"), ], - {"auto": -1, "any": -1, "none": 0}, + {"auto": -1, "any": -1, "none": 0, "required": -1}, ), ] @@ -455,8 +498,10 @@ def _generate_cases_text_and_tool_calls(mistral_tokenizer: MistralTokenizer) -> tokens = _encode_content(instruct_tokenizer, content) text_len = len(tokenizer.encode("Hello!", bos=False, eos=False)) - # Non-reasoning uses base grammar where "any" mode is `body: fcalls` β€” no content allowed - valid_for: dict[Mode, int | None] = {"auto": None, "any": 0, "none": text_len} + # Non-reasoning uses base grammar where: + # - "any" mode is `body: fcalls` β€” no content allowed + # - "required" mode is `body: content? 
fcalls` β€” content is optional + valid_for: dict[Mode, int | None] = {"auto": None, "any": 0, "none": text_len, "required": None} for mode, should_fail_on in valid_for.items(): cases.append( @@ -472,9 +517,9 @@ def _generate_cases_text_and_tool_calls(mistral_tokenizer: MistralTokenizer) -> # v11: plain_text_think mandates first, so text+fcall without think fails at 0 # v13+: think? optional, fcalls: content? fcall, so "any" mode accepts text before tool calls if tokenizer.version < TokenizerVersion.v13: - reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0} + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0, "required": 0} else: - reasoning_valid_for = {"auto": None, "any": None, "none": text_len} + reasoning_valid_for = {"auto": None, "any": None, "none": text_len, "required": None} for mode, should_fail_on in reasoning_valid_for.items(): cases.append( TestCase( @@ -501,25 +546,27 @@ def _generate_cases_thinking_v11(mistral_tokenizer: MistralTokenizer) -> list[Te ( "force_think", [TextChunk(text="Hello world!")], - {"auto": 0, "any": 0, "none": 0}, + {"auto": 0, "any": 0, "none": 0, "required": 0}, ), ( "think_without_response", [TextChunk(text="Hello!")], - {"auto": -1, "any": -1, "none": -1}, + {"auto": -1, "any": -1, "none": -1, "required": -1}, ), ( "unclosed_think", [TextChunk(text="Hello!")], - {"auto": -1, "any": -1, "none": -1}, + {"auto": -1, "any": -1, "none": -1, "required": -1}, ), ( "plain_think_with_response", [TextChunk(text="Hello!World!")], { "auto": None, + # any/required: think fcalls β€” after think, "World!" 
doesn't match fcalls "any": len(tokenizer.encode("Hello!", bos=False, eos=False)), "none": None, + "required": len(tokenizer.encode("Hello!", bos=False, eos=False)), }, ), ( @@ -532,6 +579,7 @@ def _generate_cases_thinking_v11(mistral_tokenizer: MistralTokenizer) -> list[Te "auto": None, "any": None, "none": len(tokenizer.encode("Hello!", bos=False, eos=False)), + "required": None, }, ), ( @@ -541,9 +589,12 @@ def _generate_cases_thinking_v11(mistral_tokenizer: MistralTokenizer) -> list[Te ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}')), ], { + # auto: think (content | fcalls) β€” picks content for "Ho!", then tool call rejected "auto": len(tokenizer.encode("Hello!Ho!", bos=False, eos=False)), + # any/required: think fcalls β€” after think, "Ho!" doesn't match fcalls "any": len(tokenizer.encode("Hello!", bos=False, eos=False)), "none": len(tokenizer.encode("Hello!Ho!", bos=False, eos=False)), + "required": len(tokenizer.encode("Hello!", bos=False, eos=False)), }, ), ] @@ -580,17 +631,17 @@ def _think_tokens(text: str) -> list[int]: ( "plain_text", [TextChunk(text="Hello world!")], - {"auto": None, "any": -1, "none": None}, + {"auto": None, "any": -1, "none": None, "required": -1}, ), ( "plain_think", [ThinkChunk(thinking="Hello!")], - {"auto": -1, "any": -1, "none": -1}, + {"auto": -1, "any": -1, "none": -1, "required": -1}, ), ( "plain_think_with_response", [ThinkChunk(thinking="Hello!"), TextChunk(text="World!")], - {"auto": None, "any": -1, "none": None}, + {"auto": None, "any": -1, "none": None, "required": -1}, ), ( "think_with_tool_call", @@ -602,6 +653,7 @@ def _think_tokens(text: str) -> list[int]: "auto": None, "any": None, "none": len(_think_tokens("Hello!")), + "required": None, }, ), ( @@ -615,6 +667,8 @@ def _think_tokens(text: str) -> list[int]: "auto": None, "any": None, "none": len(_think_tokens("Hello!")) + len(tokenizer.encode("World!", bos=False, eos=False)), + # required: think? content? 
fcalls β€” think+content+fcalls all present, passes + "required": None, }, ), ] @@ -653,6 +707,17 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test parallel_tool_calls=False, ) ) + # "required" mode also accepts single tool call + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=single_tokens, + should_fail_on=None, + case_name="single_tool_call", + mode="required", + parallel_tool_calls=False, + ) + ) cases.append( TestCase( tokenizer=tokenizer, @@ -667,9 +732,9 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test # v11: plain_text_think mandates first, so bare tool calls fail at 0 # v13+: think? optional, tool calls pass if tokenizer.version < TokenizerVersion.v13: - reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0} + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0, "required": 0} else: - reasoning_valid_for = {"auto": None, "any": None, "none": 0} + reasoning_valid_for = {"auto": None, "any": None, "none": 0, "required": None} for mode, should_fail_on in reasoning_valid_for.items(): cases.append( TestCase( @@ -705,6 +770,17 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test parallel_tool_calls=False, ) ) + # "required" mode with parallel_tool_calls=False also fails on second tool call + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_tokens, + should_fail_on=fail_idx, + case_name="multi_tool_call_disallowed", + mode="required", + parallel_tool_calls=False, + ) + ) return cases @@ -731,6 +807,17 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr tools=[ToolProvider.retrieve_payment_date(strict=False)], ) ) + # required mode also accepts non-strict tool calls + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=non_strict_tokens, + should_fail_on=None, + case_name="single_non_strict_tool_call", + mode="required", + 
tools=[ToolProvider.retrieve_payment_date(strict=False)], + ) + ) # 2. Correct strict tool call strict_call = [ @@ -748,6 +835,17 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr tools=[ToolProvider.retrieve_payment_date(strict=True)], ) ) + # required mode also accepts correct strict tool call + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=None, + case_name="single_strict_tool_call", + mode="required", + tools=[ToolProvider.retrieve_payment_date(strict=True)], + ) + ) # 3. Wrong args for strict tool β€” it must fail somewhere before the end. wrong_args_call = [ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"bogus": "12345"}'))] @@ -764,6 +862,17 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr tools=tools_strict, ) ) + # required mode also fails on wrong args + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_args_tokens, + should_fail_on=bogus_start, + case_name="strict_tool_call_wrong_args", + mode="required", + tools=tools_strict, + ) + ) # 4. Wrong name for strict tool wrong_name_call = [ToolCall(function=FunctionCall(name="fn1", arguments='{"transaction_id": "12345"}'))] @@ -780,6 +889,17 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr tools=tools_strict, ) ) + # required mode also fails on wrong name + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_name_tokens, + should_fail_on=fail_on_name, + case_name="strict_tool_call_wrong_name", + mode="required", + tools=tools_strict, + ) + ) # 5. 
Multiple strict tool calls (both correct) multi_strict = [ @@ -801,6 +921,20 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr ], ) ) + # required mode also accepts multiple strict tool calls + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_strict_tokens, + should_fail_on=None, + case_name="multiple_strict_tool_calls", + mode="required", + tools=[ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ], + ) + ) # 6. reasoning=True variants # v11: plain_text_think mandates first, so all bare tool calls fail at 0 @@ -818,6 +952,18 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr reasoning=True, ) ) + # required mode with reasoning in v11 also fails at 0 + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=0, + case_name="strict_tool_call_reasoning", + mode="required", + tools=tools_strict, + reasoning=True, + ) + ) else: for mode in _AUTO_ANY: cases.append( @@ -831,6 +977,18 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr reasoning=True, ) ) + # required mode with reasoning in v13+ also accepts strict tool call + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=strict_tokens, + should_fail_on=None, + case_name="strict_tool_call_reasoning", + mode="required", + tools=tools_strict, + reasoning=True, + ) + ) for mode in _AUTO_ANY: cases.append( TestCase( @@ -846,6 +1004,232 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr reasoning=True, ) ) + # required mode with reasoning in v13+ also accepts multiple strict tool calls + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_strict_tokens, + should_fail_on=None, + case_name="multiple_strict_tool_calls_reasoning", + mode="required", + tools=[ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ], + 
reasoning=True, + ) + ) + + return cases + + +def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + r"""Generate test cases for NamedToolChoice (forcing a specific tool).""" + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + + # Define the tools we'll use + tools = [ + ToolProvider.retrieve_payment_date(strict=False), + ToolProvider.retrieve_payment_status(strict=False), + ] + + # 1. NamedToolChoice for retrieve_payment_date with correct function call + named_tool_date = NamedToolChoice( + type=ToolTypes.function, + function=FunctionName(name="retrieve_payment_date"), + ) + correct_date_call = [ + ToolCall( + function=FunctionCall( + name="retrieve_payment_date", + arguments='{"transaction_id": "12345"}', + ) + ) + ] + correct_date_tokens = _encode_content(instruct_tokenizer, correct_date_call) + + # The named tool choice should accept only that specific tool + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=None, + case_name="named_tool_choice_correct", + mode=named_tool_date, + tools=tools, + parallel_tool_calls=True, + ) + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=None, + case_name="named_tool_choice_correct", + mode=named_tool_date, + tools=tools, + parallel_tool_calls=False, + ) + ) + + # 2. 
Non-strict NamedToolChoice should NOT enforce JSON arguments schema β€” + # any valid JSON object is accepted (uses %json {"type": "object"}) + arbitrary_args_call = [ + ToolCall( + function=FunctionCall( + name="retrieve_payment_date", + arguments='{"completely": "arbitrary", "keys": [1, 2, 3]}', + ) + ) + ] + arbitrary_args_tokens = _encode_content(instruct_tokenizer, arbitrary_args_call) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=arbitrary_args_tokens, + should_fail_on=None, + case_name="named_tool_choice_non_strict_arbitrary_args", + mode=named_tool_date, + tools=tools, + parallel_tool_calls=True, + ) + ) + + # 3. NamedToolChoice should reject a different tool name + wrong_tool_call = [ + ToolCall( + function=FunctionCall( + name="retrieve_payment_status", + arguments='{"transaction_id": "12345"}', + ) + ) + ] + wrong_tool_tokens = _encode_content(instruct_tokenizer, wrong_tool_call) + + # Find where the rejection happens + fail_idx = _find_first_rejection( + factory, + wrong_tool_tokens, + mode=named_tool_date, + tools=tools, + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_tool_tokens, + should_fail_on=fail_idx, + case_name="named_tool_choice_wrong_tool", + mode=named_tool_date, + tools=tools, + parallel_tool_calls=True, + ) + ) + + # 4. NamedToolChoice with reasoning mode + # v11: plain_text_think mandates first + # v13+: think? optional + if tokenizer.version < TokenizerVersion.v13: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=0, + case_name="named_tool_choice_reasoning", + mode=named_tool_date, + tools=tools, + reasoning=True, + ) + ) + else: + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=None, + case_name="named_tool_choice_reasoning", + mode=named_tool_date, + tools=tools, + reasoning=True, + ) + ) + + # 5. 
NamedToolChoice with strict tool should validate arguments + strict_tools = [ToolProvider.retrieve_payment_date(strict=True)] + named_tool_strict = NamedToolChoice( + type=ToolTypes.function, + function=FunctionName(name="retrieve_payment_date"), + ) + + # Correct args + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=correct_date_tokens, + should_fail_on=None, + case_name="named_tool_choice_strict_correct", + mode=named_tool_strict, + tools=strict_tools, + ) + ) + + # Wrong args + wrong_args_call = [ + ToolCall( + function=FunctionCall( + name="retrieve_payment_date", + arguments='{"wrong_arg": "12345"}', + ) + ) + ] + wrong_args_tokens = _encode_content(instruct_tokenizer, wrong_args_call) + fail_on_args = _find_first_rejection( + factory, + wrong_args_tokens, + mode=named_tool_strict, + tools=strict_tools, + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=wrong_args_tokens, + should_fail_on=fail_on_args, + case_name="named_tool_choice_strict_wrong_args", + mode=named_tool_strict, + tools=strict_tools, + ) + ) + + # 6. 
NamedToolChoice with non-existent tool in tools list + # This should generate a grammar that only accepts the named tool + named_nonexistent = NamedToolChoice( + type=ToolTypes.function, + function=FunctionName(name="non_existent_tool"), + ) + nonexistent_call = [ + ToolCall( + function=FunctionCall( + name="non_existent_tool", + arguments='{"arg": "value"}', + ) + ) + ] + nonexistent_tokens = _encode_content(instruct_tokenizer, nonexistent_call) + + # Without strict tool, any args are accepted + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=nonexistent_tokens, + should_fail_on=None, + case_name="named_tool_choice_nonexistent_tool", + mode=named_nonexistent, + tools=[], # No tools provided + ) + ) return cases @@ -931,6 +1315,7 @@ def _generate_cases(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory cases += _generate_cases_tool_calls(mistral_tokenizer) cases += _generate_single_tool_call(mistral_tokenizer) cases += _generate_strict_tool_calls(mistral_tokenizer, factory) + cases += _generate_named_tool_choice(mistral_tokenizer, factory) cases += _generate_cases_text_and_tool_calls(mistral_tokenizer) if not tokenizer_version < TokenizerVersion.v13: @@ -977,9 +1362,14 @@ def test_grammar(self, test_case: TestCase) -> None: grammar = factory.get_lark_for_json_schema(json_schema=test_case.json_schema) else: template = factory.select_jinja_template(reasoning=test_case.reasoning) + resolved_mode: ToolChoice + if isinstance(test_case.mode, NamedToolChoice): + resolved_mode = test_case.mode + else: + resolved_mode = ToolChoiceEnum(test_case.mode) grammar = factory.get_lark_from_jinja( template=template, - mode=ToolChoice(test_case.mode), + mode=resolved_mode, tools=test_case.tools, json_schema=test_case.json_schema, parallel_tool_calls=test_case.parallel_tool_calls, From 0d14390ed70f4ab12a15194c85e14482232a1d3f Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 25 Mar 2026 10:26:51 +0100 Subject: 
[PATCH 05/21] Expose llg tokenizer --- src/mistral_common/guidance/grammar_factory.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index dc4e884b..edb10d56 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -157,6 +157,10 @@ def __init__(self, tokenizer: MistralTokenizer) -> None: ) self._llg_tokenizer = from_mistral_tokenizer(tokenizer) + @property + def llg_tokenizer(self) -> "llg.LLTokenizer": + return self._llg_tokenizer + def select_jinja_template(self, reasoning: bool) -> str: r"""Selects and returns the appropriate jinja template content based on tokenizer version and reasoning mode. From cc5990bb800403bdb9026b4478aa0f839735534b Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 25 Mar 2026 19:50:49 +0100 Subject: [PATCH 06/21] Refactor tests to reduce LOC and handle json + emojis --- .../guidance/grammar_factory.py | 2 +- tests/guidance/test_guidance.py | 422 +++++++----------- 2 files changed, 164 insertions(+), 260 deletions(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index edb10d56..08b186b1 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -195,7 +195,7 @@ def get_lark_from_jinja( fcall = convert_tool_calls(tools, mode, parallel_tool_calls) json_schema_str = json.dumps(json_schema, ensure_ascii=False) if json_schema else None # NamedToolChoice forces a specific tool, which maps to "required" grammar - template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else mode + template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else ToolChoiceEnum(mode) return _cached_get_lark_from_jinja( template=template, mode=template_mode.value, diff --git 
a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index 7df13fd5..73feaa1d 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -35,12 +35,46 @@ EMOJI_LARK_PATH = Path(__file__).parent.parent / "data" / "emoji.lark" Mode = Literal["auto", "any", "none", "required"] -_AUTO_ANY: tuple[Mode, Mode] = ("auto", "any") -_REQUIRED_MODE: Mode = "required" +_AUTO_ANY_REQUIRED: tuple[Mode, Mode, Mode] = ("auto", "any", "required") _NUM_SPECIAL_TOKENS = 100 -_EXTRA_TOKENS = [b"a", b"b", b"c", b"f", b"de", b"he", b"llo"] +_EXTRA_TOKENS = [ + b"de", + b"he", + b"llo", + "πŸ˜ƒ".encode("utf-8"), + "πŸ˜‚".encode("utf-8"), + "😊".encode("utf-8"), + "😍".encode("utf-8"), + "😘".encode("utf-8"), + "πŸ˜—".encode("utf-8"), + "πŸ˜™".encode("utf-8"), + "😚".encode("utf-8"), + "πŸ˜‹".encode("utf-8"), + "πŸ˜›".encode("utf-8"), + "😜".encode("utf-8"), + "😝".encode("utf-8"), + "πŸ€‘".encode("utf-8"), + "πŸ€—".encode("utf-8"), + "πŸ€”".encode("utf-8"), + "🀐".encode("utf-8"), + "😐".encode("utf-8"), + "πŸ˜‘".encode("utf-8"), + "😢".encode("utf-8"), + "😬".encode("utf-8"), + "こ".encode("utf-8"), + "γ‚“".encode("utf-8"), + "に".encode("utf-8"), + "け".encode("utf-8"), + "は".encode("utf-8"), + "Ω…Ψ±Ψ­Ψ¨Ψ§".encode("utf-8"), + "Ψ¨ΩƒΩ…".encode("utf-8"), + "في".encode("utf-8"), + "ΨΉΨ§Ω„Ω…".encode("utf-8"), + "Ψ§Ω„Ψ°ΩƒΨ§Ψ‘".encode("utf-8"), + "Ψ§Ω„Ψ§Ψ΅Ψ·Ω†Ψ§ΨΉΩŠ".encode("utf-8"), +] def _build_tekken_mistral_tokenizer( @@ -49,12 +83,13 @@ def _build_tekken_mistral_tokenizer( ) -> MistralTokenizer: r"""Builds a MistralTokenizer wrapping a programmatic Tekkenizer.""" special_tokens = get_special_tokens(version, add_think=add_think) + vocab = quick_vocab(_EXTRA_TOKENS) tekkenizer = Tekkenizer( - quick_vocab(_EXTRA_TOKENS), + vocab, special_tokens=special_tokens, - pattern=r".+", - vocab_size=256 + _NUM_SPECIAL_TOKENS, + pattern=r"(?s:.+)", + vocab_size=len(vocab) + _NUM_SPECIAL_TOKENS, num_special_tokens=_NUM_SPECIAL_TOKENS, version=version, ) @@ -172,15 +207,6 @@ 
def _encode_content( instruct_tokenizer: InstructTokenizerBase, content: str | list[Any], ) -> list[int]: - r"""Encodes assistant message content into token ids. - - Args: - instruct_tokenizer: The instruct tokenizer to use. - content: Either a plain string or a list of TextChunk / ThinkChunk / ToolCall. - - Returns: - The encoded token ids. - """ tokenizer = instruct_tokenizer.tokenizer if isinstance(content, str): @@ -216,7 +242,7 @@ def _encode_content( def _find_first_rejection( factory: GrammarFactory, tokens: list[int], - mode: Literal["auto", "any", "none", "required"] | NamedToolChoice, + mode: ToolChoice, tools: list[Tool] | None, ) -> int: r"""Finds the index of the first token rejected by the grammar. @@ -234,14 +260,9 @@ def _find_first_rejection( ValueError: If all tokens are accepted. """ template = factory.select_jinja_template(reasoning=False) - if isinstance(mode, NamedToolChoice): - grammar = factory.get_lark_from_jinja( - template=template, mode=mode, tools=tools, json_schema=None, parallel_tool_calls=True - ) - else: - grammar = factory.get_lark_from_jinja( - template=template, mode=ToolChoiceEnum(mode), tools=tools, json_schema=None, parallel_tool_calls=True - ) + grammar = factory.get_lark_from_jinja( + template=template, mode=mode, tools=tools, json_schema=None, parallel_tool_calls=True + ) matcher = factory.get_matcher(grammar) for i, token in enumerate(tokens): if not matcher.consume_token(token): @@ -250,7 +271,6 @@ def _find_first_rejection( def _token_debug_repr(tokenizer: Tekkenizer, token_id: int) -> str: - r"""Returns a debug representation of a token.""" return repr(tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)) @@ -258,26 +278,18 @@ def _generate_general_cases(mistral_tokenizer: MistralTokenizer) -> list[TestCas instruct_tokenizer = mistral_tokenizer.instruct_tokenizer tokenizer = instruct_tokenizer.tokenizer assert isinstance(instruct_tokenizer, InstructTokenizerBase) + assert isinstance(tokenizer, Tekkenizer) 
cases: list[TestCase] = [] - # For programmatic tiny-vocab tokenizers, limit to ASCII-safe content - is_full_vocab = tokenizer.n_words > 1000 - if is_full_vocab: - items = { - "newline": "\n", - "blank": "_", - "text": "Hello!", - "text_with_newlines": "Hello!\n\nHow are you?\nI'm fine, thanks!", - "emojis": "πŸ˜ƒπŸ˜‚πŸ˜ŠπŸ˜πŸ˜˜πŸ˜—πŸ˜™πŸ˜šπŸ˜‹πŸ˜›πŸ˜œπŸ˜πŸ€‘πŸ€—πŸ€”πŸ€πŸ˜πŸ˜‘πŸ˜ΆπŸ˜¬", - "japanese": "こんにけは", - "arabic": "Ω…Ψ±Ψ­Ψ¨Ψ§ Ψ¨ΩƒΩ… في ΨΉΨ§Ω„Ω… Ψ§Ω„Ψ°ΩƒΨ§Ψ‘ Ψ§Ω„Ψ§Ψ΅Ψ·Ω†Ψ§ΨΉΩŠ", - } - else: - items = { - "text_a": "abc", - "text_b": "hello", - "text_de": "de", - } + items = { + "newline": "\n", + "blank": "_", + "text": "Hello!", + "text_with_newlines": "Hello!\n\nHow are you?\nI'm fine, thanks!", + "emojis": "πŸ˜ƒπŸ˜‚πŸ˜ŠπŸ˜πŸ˜˜πŸ˜—πŸ˜™πŸ˜šπŸ˜‹πŸ˜›πŸ˜œπŸ˜πŸ€‘πŸ€—πŸ€”πŸ€πŸ˜πŸ˜‘πŸ˜ΆπŸ˜¬", + "japanese": "こんにけは", + "arabic": "Ω…Ψ±Ψ­Ψ¨Ψ§ Ψ¨ΩƒΩ… في ΨΉΨ§Ω„Ω… Ψ§Ω„Ψ°ΩƒΨ§Ψ‘ Ψ§Ω„Ψ§Ψ΅Ψ·Ω†Ψ§ΨΉΩŠ", + } for case_name, content in items.items(): tokens = _encode_content(instruct_tokenizer, content) cases.append( @@ -290,9 +302,20 @@ def _generate_general_cases(mistral_tokenizer: MistralTokenizer) -> list[TestCas reasoning=False, ) ) - # v11: plain_text_think grammar requires ... first, so plain text fails - # v13+: think grammar has think? (optional), so plain text passes - reasoning_fail = 0 if tokenizer.version < TokenizerVersion.v13 else None + if tokenizer.version < TokenizerVersion.v13: + # Count how many leading whitespace-only tokens the SAFE_WS? rule will consume + # before the grammar rejects (expecting ). 
+ content_tokens = tokenizer.encode(content, bos=False, eos=False) + ws_prefix_len = 0 + for t in content_tokens: + piece = tokenizer.id_to_byte_piece(t, SpecialTokenPolicy.IGNORE) + if piece.strip(b" \t\r\n") == b"": + ws_prefix_len += 1 + else: + break + reasoning_fail = ws_prefix_len + else: + reasoning_fail = None cases.append( TestCase( tokenizer=tokenizer, @@ -306,26 +329,45 @@ def _generate_general_cases(mistral_tokenizer: MistralTokenizer) -> list[TestCas return cases +def _count_prefix_tokens(tokenizer: Tokenizer, full_text: str, prefix: str) -> int: + r"""Counts the number of tokens that cover the prefix bytes in the full-text tokenization. + + BPE tokenization is context-dependent, so encoding a prefix in isolation may produce + different tokens than encoding the full string. This helper encodes the full string + and counts how many tokens are needed to cover the byte-length of the prefix. + + Args: + tokenizer: The tokenizer to use. + full_text: The complete text to tokenize. + prefix: The prefix whose byte-length determines the token count. + + Returns: + The number of tokens covering the prefix bytes. 
+ """ + assert is_tekkenizer(tokenizer) + prefix_byte_len = len(prefix.encode("utf-8")) + tokens = tokenizer.encode(full_text, bos=False, eos=False) + byte_count = 0 + for i, t in enumerate(tokens): + byte_count += len(tokenizer.id_to_byte_piece(t, SpecialTokenPolicy.IGNORE)) + if byte_count >= prefix_byte_len: + return i + 1 + return len(tokens) + + def _generate_emoji_cases(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: instruct_tokenizer = mistral_tokenizer.instruct_tokenizer tokenizer = instruct_tokenizer.tokenizer assert isinstance(instruct_tokenizer, InstructTokenizerBase) - - # Emoji grammar needs a full vocab tokenizer (tiny programmatic tokenizers lack emoji tokens) - if tokenizer.n_words <= 1000: - return [] - emoji_lark = EMOJI_LARK_PATH.read_text(encoding="utf-8") cases: list[TestCase] = [] - # Encode raw emoji text (no assistant message wrapping) so we get pure emoji tokens items: dict[str, tuple[str, int | None]] = { "emojis_valid_a": ("πŸ˜ƒπŸ˜‚πŸ˜ŠπŸ˜πŸ˜˜πŸ˜—πŸ˜™πŸ˜šπŸ˜‹πŸ˜›πŸ˜œπŸ˜πŸ€‘πŸ€—πŸ€”πŸ€πŸ˜πŸ˜‘πŸ˜ΆπŸ˜¬", None), "emojis_valid_b": ("πŸ˜ƒπŸ˜ƒπŸ˜ƒ", None), - "emojis_invalid_text": ("πŸ˜ƒsmile", len(tokenizer.encode("πŸ˜ƒ", bos=False, eos=False))), - "emojis_invalid_space": ("πŸ˜ƒ ", len(tokenizer.encode("πŸ˜ƒ", bos=False, eos=False))), + "emojis_invalid_text": ("πŸ˜ƒsmile", _count_prefix_tokens(tokenizer, "πŸ˜ƒsmile", "πŸ˜ƒ")), + "emojis_invalid_space": ("πŸ˜ƒ ", _count_prefix_tokens(tokenizer, "πŸ˜ƒ ", "πŸ˜ƒ")), } for case_name, (text, should_fail_on) in items.items(): - # Use raw encode to avoid assistant message wrapping (BOS/EOS) tokens = tokenizer.encode(text, bos=False, eos=False) cases.append( TestCase( @@ -350,7 +392,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test ( "single_fcall", [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}'))], - {"auto": None, "any": None, "none": 0}, + {"auto": None, "any": None, "none": 0, "required": None}, ), ( 
"multi_fcall", @@ -359,7 +401,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test ToolCall(function=FunctionCall(name="hello_1", arguments='{"arg1": "val1", "arg2": "val2"}')), ToolCall(function=FunctionCall(name="hello_2_3", arguments='{"arg1": "val1", "arg2": "val2"}')), ], - {"auto": None, "any": None, "none": 0}, + {"auto": None, "any": None, "none": 0, "required": None}, ), ( "emoji_fcall", @@ -368,7 +410,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test function=FunctionCall(name="he🧦🧦o", arguments='{"arg1": "🐱", "arg2": "🐢", "arg🧦": "🧦"}'), ) ], - {"auto": None, "any": None, "none": 0}, + {"auto": None, "any": None, "none": 0, "required": None}, ), ( "pretty_printed_args", @@ -380,12 +422,12 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test ), ) ], - {"auto": None, "any": None, "none": 0}, + {"auto": None, "any": None, "none": 0, "required": None}, ), ( "japanese_fcall", [ToolCall(function=FunctionCall(name="こんにけは", arguments='{"こん": "にけは"}'))], - {"auto": None, "any": None, "none": 0}, + {"auto": None, "any": None, "none": 0, "required": None}, ), ] for case_name, content, valid_for in items: @@ -400,25 +442,11 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test mode=mode, ) ) - # The "required" mode should behave like "any" - tool calls are required - # and text content is optional (content? fcalls) - if case_name != "single_non_strict_tool_call": # handled in strict tests - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=tokens, - should_fail_on=valid_for.get("any"), - case_name=case_name, - mode="required", - ) - ) - # v11: plain_text_think mandates first, so bare tool calls fail at 0 - # v13+: think grammar has think? (optional), tool calls pass; fcalls allows content? 
prefix if tokenizer.version < TokenizerVersion.v13: - reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0} + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0, "required": 0} else: - reasoning_valid_for = {"auto": None, "any": None, "none": 0} + reasoning_valid_for = {"auto": None, "any": None, "none": 0, "required": None} for mode, should_fail_on in reasoning_valid_for.items(): cases.append( TestCase( @@ -430,23 +458,8 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test reasoning=True, ) ) - # "required" mode with reasoning - if tokenizer.version < TokenizerVersion.v13: - reasoning_required_fail = 0 - else: - reasoning_required_fail = None - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=tokens, - should_fail_on=reasoning_required_fail, - case_name=f"{case_name}_reasoning", - mode="required", - reasoning=True, - ) - ) - # Broken / missing args edge cases (token-level construction) + # Broken / missing args edge cases token_items: list[tuple[str, list[int], dict[Mode, int | None]]] = [ ( "fcall_broken_args", @@ -498,9 +511,6 @@ def _generate_cases_text_and_tool_calls(mistral_tokenizer: MistralTokenizer) -> tokens = _encode_content(instruct_tokenizer, content) text_len = len(tokenizer.encode("Hello!", bos=False, eos=False)) - # Non-reasoning uses base grammar where: - # - "any" mode is `body: fcalls` β€” no content allowed - # - "required" mode is `body: content? fcalls` β€” content is optional valid_for: dict[Mode, int | None] = {"auto": None, "any": 0, "none": text_len, "required": None} for mode, should_fail_on in valid_for.items(): @@ -514,8 +524,6 @@ def _generate_cases_text_and_tool_calls(mistral_tokenizer: MistralTokenizer) -> ) ) - # v11: plain_text_think mandates first, so text+fcall without think fails at 0 - # v13+: think? optional, fcalls: content? 
fcall, so "any" mode accepts text before tool calls if tokenizer.version < TokenizerVersion.v13: reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0, "required": 0} else: @@ -696,7 +704,7 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test cases: list[TestCase] = [] single_call = [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}'))] single_tokens = _encode_content(instruct_tokenizer, single_call) - for mode in _AUTO_ANY: + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -707,17 +715,6 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test parallel_tool_calls=False, ) ) - # "required" mode also accepts single tool call - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=single_tokens, - should_fail_on=None, - case_name="single_tool_call", - mode="required", - parallel_tool_calls=False, - ) - ) cases.append( TestCase( tokenizer=tokenizer, @@ -729,8 +726,6 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test ) ) - # v11: plain_text_think mandates first, so bare tool calls fail at 0 - # v13+: think? optional, tool calls pass if tokenizer.version < TokenizerVersion.v13: reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0, "required": 0} else: @@ -748,9 +743,6 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test ) ) - # Multi tool call should fail when parallel_tool_calls=False - # Each tool call is a separate [TOOL_CALLS]...[ARGS]... sequence; - # the second [TOOL_CALLS] is where it fails. 
multi_calls = [ ToolCall(function=FunctionCall(name="fn1", arguments='{"arg1": "val1", "arg2": "val2"}')), ToolCall(function=FunctionCall(name="fn2", arguments='{"arg1": "val1", "arg2": "val2"}')), @@ -759,7 +751,7 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test single_tokens_with_eos = _encode_content(instruct_tokenizer, [multi_calls[0]]) fail_idx = len(single_tokens_with_eos) - 1 - for mode in _AUTO_ANY: + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -770,17 +762,6 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test parallel_tool_calls=False, ) ) - # "required" mode with parallel_tool_calls=False also fails on second tool call - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=multi_tokens, - should_fail_on=fail_idx, - case_name="multi_tool_call_disallowed", - mode="required", - parallel_tool_calls=False, - ) - ) return cases @@ -796,7 +777,7 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr # 1. Non-strict tools β€” any function name/args accepted non_strict_call = [ToolCall(function=FunctionCall(name="fn1", arguments='{"arg1": "val1", "arg2": "val2"}'))] non_strict_tokens = _encode_content(instruct_tokenizer, non_strict_call) - for mode in _AUTO_ANY: + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -807,24 +788,13 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr tools=[ToolProvider.retrieve_payment_date(strict=False)], ) ) - # required mode also accepts non-strict tool calls - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=non_strict_tokens, - should_fail_on=None, - case_name="single_non_strict_tool_call", - mode="required", - tools=[ToolProvider.retrieve_payment_date(strict=False)], - ) - ) # 2. 
Correct strict tool call strict_call = [ ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"transaction_id": "12345"}')) ] strict_tokens = _encode_content(instruct_tokenizer, strict_call) - for mode in _AUTO_ANY: + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -835,23 +805,12 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr tools=[ToolProvider.retrieve_payment_date(strict=True)], ) ) - # required mode also accepts correct strict tool call - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=strict_tokens, - should_fail_on=None, - case_name="single_strict_tool_call", - mode="required", - tools=[ToolProvider.retrieve_payment_date(strict=True)], - ) - ) # 3. Wrong args for strict tool β€” it must fail somewhere before the end. wrong_args_call = [ToolCall(function=FunctionCall(name="retrieve_payment_date", arguments='{"bogus": "12345"}'))] wrong_args_tokens = _encode_content(instruct_tokenizer, wrong_args_call) - bogus_start = _find_first_rejection(factory, wrong_args_tokens, mode="auto", tools=tools_strict) - for mode in _AUTO_ANY: + bogus_start = _find_first_rejection(factory, wrong_args_tokens, mode=ToolChoiceEnum.auto, tools=tools_strict) + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -862,23 +821,12 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr tools=tools_strict, ) ) - # required mode also fails on wrong args - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=wrong_args_tokens, - should_fail_on=bogus_start, - case_name="strict_tool_call_wrong_args", - mode="required", - tools=tools_strict, - ) - ) # 4. 
Wrong name for strict tool wrong_name_call = [ToolCall(function=FunctionCall(name="fn1", arguments='{"transaction_id": "12345"}'))] wrong_name_tokens = _encode_content(instruct_tokenizer, wrong_name_call) - fail_on_name = _find_first_rejection(factory, wrong_name_tokens, mode="auto", tools=tools_strict) - for mode in _AUTO_ANY: + fail_on_name = _find_first_rejection(factory, wrong_name_tokens, mode=ToolChoiceEnum.auto, tools=tools_strict) + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -889,17 +837,6 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr tools=tools_strict, ) ) - # required mode also fails on wrong name - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=wrong_name_tokens, - should_fail_on=fail_on_name, - case_name="strict_tool_call_wrong_name", - mode="required", - tools=tools_strict, - ) - ) # 5. Multiple strict tool calls (both correct) multi_strict = [ @@ -907,7 +844,7 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr ToolCall(function=FunctionCall(name="retrieve_payment_status", arguments='{"transaction_id": "12345"}')), ] multi_strict_tokens = _encode_content(instruct_tokenizer, multi_strict) - for mode in _AUTO_ANY: + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -921,26 +858,10 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr ], ) ) - # required mode also accepts multiple strict tool calls - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=multi_strict_tokens, - should_fail_on=None, - case_name="multiple_strict_tool_calls", - mode="required", - tools=[ - ToolProvider.retrieve_payment_date(strict=True), - ToolProvider.retrieve_payment_status(strict=True), - ], - ) - ) # 6. reasoning=True variants - # v11: plain_text_think mandates first, so all bare tool calls fail at 0 - # v13+: think? 
optional, tool calls behave as without reasoning if tokenizer.version < TokenizerVersion.v13: - for mode in _AUTO_ANY: + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -952,20 +873,8 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr reasoning=True, ) ) - # required mode with reasoning in v11 also fails at 0 - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=strict_tokens, - should_fail_on=0, - case_name="strict_tool_call_reasoning", - mode="required", - tools=tools_strict, - reasoning=True, - ) - ) else: - for mode in _AUTO_ANY: + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -977,19 +886,7 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr reasoning=True, ) ) - # required mode with reasoning in v13+ also accepts strict tool call - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=strict_tokens, - should_fail_on=None, - case_name="strict_tool_call_reasoning", - mode="required", - tools=tools_strict, - reasoning=True, - ) - ) - for mode in _AUTO_ANY: + for mode in _AUTO_ANY_REQUIRED: cases.append( TestCase( tokenizer=tokenizer, @@ -1004,21 +901,6 @@ def _generate_strict_tool_calls(mistral_tokenizer: MistralTokenizer, factory: Gr reasoning=True, ) ) - # required mode with reasoning in v13+ also accepts multiple strict tool calls - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=multi_strict_tokens, - should_fail_on=None, - case_name="multiple_strict_tool_calls_reasoning", - mode="required", - tools=[ - ToolProvider.retrieve_payment_date(strict=True), - ToolProvider.retrieve_payment_status(strict=True), - ], - reasoning=True, - ) - ) return cases @@ -1031,7 +913,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr cases: list[TestCase] = [] - # Define the tools we'll use tools = [ ToolProvider.retrieve_payment_date(strict=False), 
ToolProvider.retrieve_payment_status(strict=False), @@ -1052,7 +933,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr ] correct_date_tokens = _encode_content(instruct_tokenizer, correct_date_call) - # The named tool choice should accept only that specific tool cases.append( TestCase( tokenizer=tokenizer, @@ -1077,7 +957,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr ) # 2. Non-strict NamedToolChoice should NOT enforce JSON arguments schema β€” - # any valid JSON object is accepted (uses %json {"type": "object"}) arbitrary_args_call = [ ToolCall( function=FunctionCall( @@ -1110,7 +989,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr ] wrong_tool_tokens = _encode_content(instruct_tokenizer, wrong_tool_call) - # Find where the rejection happens fail_idx = _find_first_rejection( factory, wrong_tool_tokens, @@ -1130,8 +1008,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr ) # 4. NamedToolChoice with reasoning mode - # v11: plain_text_think mandates first - # v13+: think? optional if tokenizer.version < TokenizerVersion.v13: cases.append( TestCase( @@ -1164,7 +1040,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr function=FunctionName(name="retrieve_payment_date"), ) - # Correct args cases.append( TestCase( tokenizer=tokenizer, @@ -1176,7 +1051,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr ) ) - # Wrong args wrong_args_call = [ ToolCall( function=FunctionCall( @@ -1204,7 +1078,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr ) # 6. 
NamedToolChoice with non-existent tool in tools list - # This should generate a grammar that only accepts the named tool named_nonexistent = NamedToolChoice( type=ToolTypes.function, function=FunctionName(name="non_existent_tool"), @@ -1219,7 +1092,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr ] nonexistent_tokens = _encode_content(instruct_tokenizer, nonexistent_call) - # Without strict tool, any args are accepted cases.append( TestCase( tokenizer=tokenizer, @@ -1227,22 +1099,44 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr should_fail_on=None, case_name="named_tool_choice_nonexistent_tool", mode=named_nonexistent, - tools=[], # No tools provided + tools=[], ) ) return cases -def _generate_json_schema(mistral_tokenizer: MistralTokenizer) -> list[TestCase]: +def _find_first_json_schema_rejection( + factory: GrammarFactory, + tokens: list[int], + json_schema: dict[str, Any], +) -> int: + r"""Finds the index of the first token rejected by the JSON schema grammar. + + Args: + factory: The grammar factory. + tokens: The token sequence to test. + json_schema: The JSON schema to validate against. + + Returns: + The index of the first rejected token. + + Raises: + ValueError: If all tokens are accepted. 
+ """ + grammar = factory.get_lark_for_json_schema(json_schema=json_schema) + matcher = factory.get_matcher(grammar) + for i, token in enumerate(tokens): + if not matcher.consume_token(token): + return i + raise ValueError("All tokens were accepted β€” expected a rejection") + + +def _generate_json_schema(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: instruct_tokenizer = mistral_tokenizer.instruct_tokenizer tokenizer = instruct_tokenizer.tokenizer assert isinstance(instruct_tokenizer, InstructTokenizerBase) - # JSON schema validation requires a full vocab tokenizer - if tokenizer.n_words <= 1000: - return [] - cases: list[TestCase] = [] items: list[tuple[str, str, int | None, dict[str, Any]]] = [ ( @@ -1269,12 +1163,6 @@ def _generate_json_schema(mistral_tokenizer: MistralTokenizer) -> list[TestCase] 0, {"type": "object"}, ), - ( - "basic_person_invalid", - '{"age": "John", "name": 30}', - 1, - SchemaProvider.basic_person(), - ), ( "basic_person_non_strict_valid", '{"age": "John", "name": 30}', @@ -1301,6 +1189,22 @@ def _generate_json_schema(mistral_tokenizer: MistralTokenizer) -> list[TestCase] json_schema=json_schema, ) ) + + # Cases where should_fail_on must be computed dynamically because the exact + # rejection index depends on the grammar engine's internal byte-level parsing. 
+ person_invalid_tokens = _encode_content(instruct_tokenizer, '{"age": "John", "name": 30}') + person_schema = SchemaProvider.basic_person() + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=person_invalid_tokens, + should_fail_on=_find_first_json_schema_rejection(factory, person_invalid_tokens, person_schema), + case_name="basic_person_invalid", + mode="auto", + json_schema=person_schema, + ) + ) + return cases @@ -1311,14 +1215,14 @@ def _generate_cases(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory cases = _generate_general_cases(mistral_tokenizer) cases += _generate_emoji_cases(mistral_tokenizer) - cases += _generate_json_schema(mistral_tokenizer) + cases += _generate_json_schema(mistral_tokenizer, factory) cases += _generate_cases_tool_calls(mistral_tokenizer) cases += _generate_single_tool_call(mistral_tokenizer) cases += _generate_strict_tool_calls(mistral_tokenizer, factory) cases += _generate_named_tool_choice(mistral_tokenizer, factory) cases += _generate_cases_text_and_tool_calls(mistral_tokenizer) - if not tokenizer_version < TokenizerVersion.v13: + if tokenizer_version >= TokenizerVersion.v13: cases += _generate_cases_thinking(mistral_tokenizer) else: cases += _generate_cases_thinking_v11(mistral_tokenizer) From 1487e7a2a939bdad10e4a78f01031a4904fa26cb Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 25 Mar 2026 20:38:40 +0100 Subject: [PATCH 07/21] Improve coverage --- tests/guidance/test_guidance.py | 107 ++++++++++++++++++++++++++++--- tests/guidance/test_tokenizer.py | 44 ++++++++++++- 2 files changed, 142 insertions(+), 9 deletions(-) diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index 73feaa1d..db4fc514 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -6,7 +6,7 @@ import pytest from pydantic import BaseModel, ConfigDict, Field, model_validator -from mistral_common.guidance.grammar_factory import 
GrammarFactory +from mistral_common.guidance.grammar_factory import GrammarFactory, convert_tool_calls from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.normalize import get_normalizer @@ -153,22 +153,18 @@ class SchemaProvider: @staticmethod def basic_person() -> dict[str, Any]: class Person(BaseModel): + model_config = ConfigDict(extra="forbid") name: str age: int - class Config: - extra = "forbid" - return Person.model_json_schema() @staticmethod def basic_dict_of_list() -> dict[str, Any]: class DoMerge(BaseModel): + model_config = ConfigDict(extra="forbid") new_clusters: dict[str, list[str]] = Field(default_factory=dict) - class Config: - extra = "forbid" - return DoMerge.model_json_schema() @@ -224,7 +220,9 @@ def _encode_content( is_before_last_user_message=False, continue_message=False, ) - # Strip trailing EOS before adding tool calls + # The instruct tokenizer appends EOS after content, but when tool calls follow, + # the EOS should come after the last tool call, not after the content. Strip it + # here so the tool call tokens are appended directly after content tokens. while tokens and tokens[-1] == tokenizer.eos_id: tokens.pop() @@ -1304,6 +1302,15 @@ def test_grammar(self, test_case: TestCase) -> None: f"\n\nGrammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" ) + # For fully accepted sequences, verify the matcher reached a valid terminal state. + # Raw lark grammars (e.g., emoji matcher) don't consume EOS, so they may not reach + # a stopped state β€” skip the check for those. 
+ if test_case.should_fail_on is None and test_case.raw_lark is None: + assert matcher.is_stopped(), ( + f"Matcher did not reach terminal state after consuming all tokens.\n\n" + f"Grammar:\n{grammar}\n\nTokens: {debug_tokens}\n\nIds: {test_case.tokens}" + ) + @pytest.mark.parametrize( ("tokenizer", "expected"), [ @@ -1315,3 +1322,87 @@ def test_grammar(self, test_case: TestCase) -> None: ) def test_grammar_factory_is_supported(self, tokenizer: MistralTokenizer, expected: bool) -> None: assert GrammarFactory.is_supported(tokenizer) is expected + + @pytest.mark.parametrize( + "tokenizer", + [MistralTokenizer.v1(), MistralTokenizer.v3(is_tekken=True)], + ) + def test_grammar_factory_init_rejects_unsupported(self, tokenizer: MistralTokenizer) -> None: + with pytest.raises(ValueError, match="Guidance requires a Tekken tokenizer with version >= v11"): + GrammarFactory(tokenizer) + + def test_get_matcher_rejects_invalid_grammar(self, v11_tekken: MistralTokenizer) -> None: + factory = GrammarFactory(v11_tekken) + with pytest.raises(ValueError, match="Invalid grammar"): + factory.get_matcher("start: INVALID_RULE_REF_THAT_DOES_NOT_EXIST") + + +class TestConvertToolCalls: + def test_none_mode(self) -> None: + result = convert_tool_calls(tools=None, mode=ToolChoiceEnum.none, parallel_tool_calls=False) + assert result == "" + + def test_none_mode_with_tools(self) -> None: + tools = [ToolProvider.retrieve_payment_date(strict=True)] + result = convert_tool_calls(tools=tools, mode=ToolChoiceEnum.none, parallel_tool_calls=True) + assert result == "" + + def test_auto_mode_no_tools(self) -> None: + result = convert_tool_calls(tools=None, mode=ToolChoiceEnum.auto, parallel_tool_calls=False) + assert "" in result + assert "" in result + assert "/.+/" in result + assert not result.endswith(")+") + + def test_auto_mode_non_strict(self) -> None: + tools = [ToolProvider.retrieve_payment_date(strict=False)] + result = convert_tool_calls(tools=tools, mode=ToolChoiceEnum.auto, 
parallel_tool_calls=False) + assert "" in result + assert "/.+/" in result + assert not result.endswith(")+") + + def test_auto_mode_strict(self) -> None: + tools = [ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ] + result = convert_tool_calls(tools=tools, mode=ToolChoiceEnum.auto, parallel_tool_calls=False) + assert '"retrieve_payment_date"' in result + assert '"retrieve_payment_status"' in result + assert not result.endswith(")+") + + def test_named_tool_choice_non_strict(self) -> None: + named = NamedToolChoice(function=FunctionName(name="retrieve_payment_date")) + tools = [ToolProvider.retrieve_payment_date(strict=False)] + result = convert_tool_calls(tools=tools, mode=named, parallel_tool_calls=False) + assert '"retrieve_payment_date"' in result + assert "/.+/" not in result + assert not result.endswith(")+") + + def test_named_tool_choice_strict(self) -> None: + named = NamedToolChoice(function=FunctionName(name="retrieve_payment_date")) + tools = [ + ToolProvider.retrieve_payment_date(strict=True), + ToolProvider.retrieve_payment_status(strict=True), + ] + result = convert_tool_calls(tools=tools, mode=named, parallel_tool_calls=False) + assert '"retrieve_payment_date"' in result + assert '"retrieve_payment_status"' not in result + assert not result.endswith(")+") + + def test_parallel_tool_calls(self) -> None: + result = convert_tool_calls(tools=None, mode=ToolChoiceEnum.auto, parallel_tool_calls=True) + assert result.startswith("(") and result.endswith(")+") + + def test_empty_params_strict_tool(self) -> None: + tool = Tool(function=Function(name="empty_fn", parameters={}, strict=True)) + result = convert_tool_calls(tools=[tool], mode=ToolChoiceEnum.auto, parallel_tool_calls=False) + assert '"additionalProperties": false' in result + assert '"properties": {}' in result + assert not result.endswith(")+") + + def test_named_tool_not_in_strict_tools_raises(self) -> None: + named = 
NamedToolChoice(function=FunctionName(name="non_existent_tool")) + tools = [ToolProvider.retrieve_payment_date(strict=True)] + with pytest.raises(StopIteration): + convert_tool_calls(tools=tools, mode=named, parallel_tool_calls=False) diff --git a/tests/guidance/test_tokenizer.py b/tests/guidance/test_tokenizer.py index d302ddf3..491444ec 100644 --- a/tests/guidance/test_tokenizer.py +++ b/tests/guidance/test_tokenizer.py @@ -12,7 +12,7 @@ from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens, Tokenizer, TokenizerVersion from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV7 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer -from mistral_common.tokens.tokenizers.tekken import Tekkenizer +from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer from tests.test_tekken import get_special_tokens, quick_vocab _NUM_SPECIAL_TOKENS = 100 @@ -115,6 +115,48 @@ def test_call_encodes_string(self, tekkenizer: Tekkenizer, llg_tokenizer: Mistra test_string = "abc" assert llg_tokenizer(test_string) == tekkenizer.encode(test_string, bos=False, eos=False) + def test_init_rejects_invalid_special_token_format(self) -> None: + base = list(Tekkenizer.DEPRECATED_SPECIAL_TOKENS) + next_rank = len(base) + special_tokens: list[SpecialTokenInfo] = [ + *base, + SpecialTokenInfo(rank=next_rank, token_str="INVALID_NO_BRACKETS", is_control=True), + ] + vocab = quick_vocab() + num_special = len(special_tokens) + tekkenizer = Tekkenizer( + vocab, + special_tokens=special_tokens, + pattern=r".+", + vocab_size=len(vocab) + num_special, + num_special_tokens=num_special, + version=TokenizerVersion.v7, + ) + with pytest.raises(ValueError, match="Invalid special token"): + MistralLLGTokenizer(tekkenizer) + + def test_init_rejects_duplicate_special_tokens(self) -> None: + base = list(Tekkenizer.DEPRECATED_SPECIAL_TOKENS) + next_rank = len(base) + # Both [CUSTOM] and map to after bracket conversion + 
special_tokens: list[SpecialTokenInfo] = [ + *base, + SpecialTokenInfo(rank=next_rank, token_str="", is_control=True), + SpecialTokenInfo(rank=next_rank + 1, token_str="[CUSTOM]", is_control=True), + ] + vocab = quick_vocab() + num_special = len(special_tokens) + tekkenizer = Tekkenizer( + vocab, + special_tokens=special_tokens, + pattern=r".+", + vocab_size=len(vocab) + num_special, + num_special_tokens=num_special, + version=TokenizerVersion.v7, + ) + with pytest.raises(ValueError, match="Duplicate special token"): + MistralLLGTokenizer(tekkenizer) + class TestFromMistralTokenizer: def test_properties(self, tekkenizer: Tekkenizer, ll_tokenizer: llg.LLTokenizer) -> None: From b3a5387428c00faf09c9c91ee31dbcf0811ce570 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Thu, 26 Mar 2026 11:58:49 +0100 Subject: [PATCH 08/21] Refactor attributes to properties --- src/mistral_common/guidance/tokenizer.py | 50 ++++++++++++++---------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/mistral_common/guidance/tokenizer.py b/src/mistral_common/guidance/tokenizer.py index 57cbdb93..2a3f803b 100644 --- a/src/mistral_common/guidance/tokenizer.py +++ b/src/mistral_common/guidance/tokenizer.py @@ -11,19 +11,27 @@ class MistralLLGTokenizer: - r"""Wraps a Tekken tokenizer for use with llguidance. + r"""Wraps a Tekken tokenizer for use with llguidance.""" - Attributes: - eos_token_id: The end of string token id. - bos_token_id: The beginning of string token id. - tokens: The list of token byte representations. - special_token_ids: The list of special token ids. 
- """ + @property + def bos_token_id(self) -> int: + r"""The beginning of string token id.""" + return self._tokenizer.bos_id + + @property + def eos_token_id(self) -> int: + r"""The end of string token id.""" + return self._tokenizer.eos_id + + @property + def tokens(self) -> list[bytes]: + r"""The list of token byte representations.""" + return self._tokens - eos_token_id: int - bos_token_id: int - tokens: list[bytes] - special_token_ids: list[int] + @property + def special_token_ids(self) -> list[bytes]: + r"""The list of special token ids.""" + return self._special_token_ids def __init__(self, tokenizer: Tokenizer) -> None: r"""Initialize the wrapper. @@ -36,14 +44,13 @@ def __init__(self, tokenizer: Tokenizer) -> None: ValueError: If a special token has an invalid format. """ assert_llguidance_installed() + if not is_tekkenizer(tokenizer): raise TypeError(f"Guidance only supports Tekken tokenizers, got {type(tokenizer)}") - self._tokenizer = tokenizer - self.eos_token_id = self._tokenizer.eos_id - self.bos_token_id = self._tokenizer.bos_id - self.tokens: list[bytes] = [] - self.special_token_ids: list[int] = [] + self._tokenizer = tokenizer + self._tokens: list[bytes] = [] + self._special_token_ids: list[int] = [] seen_special_tokens: set[str] = set() for i in range(self._tokenizer.n_words): @@ -61,15 +68,16 @@ def __init__(self, tokenizer: Tokenizer) -> None: if token_rep_llg in seen_special_tokens: raise ValueError(f"Duplicate special token: {token_rep_llg} (already seen: {seen_special_tokens})") seen_special_tokens.add(token_rep_llg) - self.special_token_ids.append(i) - self.tokens.append(token_rep_llg.encode("utf-8")) + self._special_token_ids.append(i) + self._tokens.append(token_rep_llg.encode("utf-8")) else: token_bytes = self._tokenizer.id_to_byte_piece(i, SpecialTokenPolicy.RAISE) - self.tokens.append(token_bytes) + self._tokens.append(token_bytes) - if len(self.special_token_ids) != self._tokenizer.num_special_tokens: + if 
len(self._special_token_ids) != self._tokenizer.num_special_tokens: raise ValueError( - f"Expected {self._tokenizer.num_special_tokens} special tokens, but found {len(self.special_token_ids)}" + f"Expected {self._tokenizer.num_special_tokens} special tokens, but found " + f"{len(self._special_token_ids)}" ) def __call__(self, s: str, *args: Any, **kwargs: Any) -> list[int]: From bc0ef69a8262abf2bb792b02301a1c43f5d9b393 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Thu, 26 Mar 2026 15:28:11 +0100 Subject: [PATCH 09/21] mypy --- src/mistral_common/guidance/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mistral_common/guidance/tokenizer.py b/src/mistral_common/guidance/tokenizer.py index 2a3f803b..a37f789d 100644 --- a/src/mistral_common/guidance/tokenizer.py +++ b/src/mistral_common/guidance/tokenizer.py @@ -29,7 +29,7 @@ def tokens(self) -> list[bytes]: return self._tokens @property - def special_token_ids(self) -> list[bytes]: + def special_token_ids(self) -> list[int]: r"""The list of special token ids.""" return self._special_token_ids From 0ebe85d5e3114982e679d6b7078cb4790dde38b4 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:11:40 +0100 Subject: [PATCH 10/21] Update json grammar --- .../guidance/data/base_grammar.lark.jinja | 6 +- .../data/plain_text_think_grammar.lark.jinja | 8 +- .../guidance/data/think_grammar.lark.jinja | 14 +- .../guidance/grammar_factory.py | 30 ++- tests/guidance/test_guidance.py | 212 +++++++++++++++++- 5 files changed, 260 insertions(+), 10 deletions(-) diff --git a/src/mistral_common/guidance/data/base_grammar.lark.jinja b/src/mistral_common/guidance/data/base_grammar.lark.jinja index b43b1998..84fbd433 100644 --- a/src/mistral_common/guidance/data/base_grammar.lark.jinja +++ b/src/mistral_common/guidance/data/base_grammar.lark.jinja @@ -1,5 +1,9 @@ {% if json_schema_str != 
None -%} +{% if json_only -%} +start: SAFE_WS? %json {{ json_schema_str }} +{% else -%} start: body | SAFE_WS? %json {{ json_schema_str }} +{% endif -%} {% else -%} start: body {% endif -%} @@ -18,4 +22,4 @@ fcalls: {{ fcall }} content: (/(.|\n)+/)+ -SAFE_WS: /[ \t\r\n]+/ +SAFE_WS: /[ \t\r\n]{1,8}/ diff --git a/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja b/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja index 489d7271..7c98243d 100644 --- a/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja +++ b/src/mistral_common/guidance/data/plain_text_think_grammar.lark.jinja @@ -1,5 +1,9 @@ {% if json_schema_str != None -%} -start: body | %json {{ json_schema_str }} +{% if json_only -%} +start: SAFE_WS? %json {{ json_schema_str }} +{% else -%} +start: body | SAFE_WS? %json {{ json_schema_str }} +{% endif -%} {% else -%} start: body {% endif -%} @@ -22,4 +26,4 @@ content: NO_THINK think: SAFE_WS? // text_first_optional end_think -SAFE_WS: /[ \t\r\n]+/ +SAFE_WS: /[ \t\r\n]{1,8}/ diff --git a/src/mistral_common/guidance/data/think_grammar.lark.jinja b/src/mistral_common/guidance/data/think_grammar.lark.jinja index 6acc20e6..f49e417c 100644 --- a/src/mistral_common/guidance/data/think_grammar.lark.jinja +++ b/src/mistral_common/guidance/data/think_grammar.lark.jinja @@ -1,5 +1,15 @@ {% if json_schema_str != None -%} -start: body | %json {{ json_schema_str }} +{% if json_only -%} +start: json_grammar +{% else -%} +start: body | json_grammar +{% endif -%} + +{% if think_with_json -%} +json_grammar: think? SAFE_WS? %json {{ json_schema_str }} +{% else -%} +json_grammar: SAFE_WS? 
%json {{ json_schema_str }} +{% endif -%} {% else -%} start: body {% endif -%} @@ -18,4 +28,4 @@ content: (/(.|\n)+/)+ think: content -SAFE_WS: /[ \t\r\n]+/ +SAFE_WS: /[ \t\r\n]{1,8}/ diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 08b186b1..444c0b4e 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -42,6 +42,8 @@ def _cached_get_lark_from_jinja( fcall: str, json_schema_str: str | None, parallel_tool_calls: bool, + json_only: bool = False, + think_with_json: bool = False, ) -> str: jinja_template = Template(template) lark_grammar = jinja_template.render( @@ -49,6 +51,8 @@ def _cached_get_lark_from_jinja( fcall=fcall, json_schema_str=json_schema_str, parallel_tool_calls=parallel_tool_calls, + json_only=json_only, + think_with_json=think_with_json, ) return lark_grammar @@ -179,6 +183,7 @@ def get_lark_from_jinja( tools: list[Tool] | None, json_schema: dict[str, Any] | None, parallel_tool_calls: bool, + json_only: bool = False, ) -> str: r"""Renders a lark grammar from a jinja template. @@ -188,6 +193,7 @@ def get_lark_from_jinja( tools: The list of tools available. json_schema: JSON schema to additionally allow, unioned with the grammar. parallel_tool_calls: Whether parallel tool calls are allowed. + json_only: If True, generates only JSON schema grammar without text/tool call alternatives. Returns: The rendered lark grammar string. 
@@ -196,17 +202,35 @@ def get_lark_from_jinja( json_schema_str = json.dumps(json_schema, ensure_ascii=False) if json_schema else None # NamedToolChoice forces a specific tool, which maps to "required" grammar template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else ToolChoiceEnum(mode) + think_with_json = self._tokenizer.version.supports_model_settings return _cached_get_lark_from_jinja( template=template, mode=template_mode.value, fcall=fcall, json_schema_str=json_schema_str, parallel_tool_calls=parallel_tool_calls, + json_only=json_only, + think_with_json=think_with_json, ) - def get_lark_for_json_schema(self, json_schema: dict[str, Any]) -> str: - r"""Returns a lark grammar that only accepts JSON objects matching the given schema.""" - return f"start: SAFE_WS? %json {json.dumps(json_schema, ensure_ascii=False)} \nSAFE_WS: /[ \t\r\n]+/" + def get_lark_for_json_schema(self, template: str, json_schema: dict[str, Any]) -> str: + r"""Returns a lark grammar that only accepts JSON objects matching the given schema. + + Args: + template: Jinja template to render as a string. + json_schema: The JSON schema to validate against. + + Returns: + The rendered lark grammar string that only matches the given JSON schema. + """ + return self.get_lark_from_jinja( + template=template, + mode=ToolChoiceEnum.none, + tools=None, + json_schema=json_schema, + parallel_tool_calls=True, + json_only=True, + ) def get_matcher(self, lark: str) -> "llg.LLMatcher": r"""Creates an LLMatcher from a lark grammar string. 
diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index db4fc514..2bbc7de0 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -10,6 +10,7 @@ from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.normalize import get_normalizer +from mistral_common.protocol.instruct.request import ReasoningEffort from mistral_common.protocol.instruct.tool_calls import ( Function, FunctionCall, @@ -27,8 +28,10 @@ InstructTokenizerBase, InstructTokenizerV11, InstructTokenizerV13, + InstructTokenizerV15, ) from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.model_settings_builder import EnumBuilder, ModelSettingsBuilder from mistral_common.tokens.tokenizers.tekken import Tekkenizer, is_tekkenizer from tests.test_tekken import get_special_tokens, quick_vocab @@ -80,6 +83,7 @@ def _build_tekken_mistral_tokenizer( version: TokenizerVersion, add_think: bool = False, + model_settings_builder: ModelSettingsBuilder | None = None, ) -> MistralTokenizer: r"""Builds a MistralTokenizer wrapping a programmatic Tekkenizer.""" special_tokens = get_special_tokens(version, add_think=add_think) @@ -92,6 +96,7 @@ def _build_tekken_mistral_tokenizer( vocab_size=len(vocab) + _NUM_SPECIAL_TOKENS, num_special_tokens=_NUM_SPECIAL_TOKENS, version=version, + model_settings_builder=model_settings_builder, ) match version: @@ -99,6 +104,8 @@ def _build_tekken_mistral_tokenizer( instruct_tokenizer = InstructTokenizerV11(tekkenizer) case TokenizerVersion.v13: instruct_tokenizer = InstructTokenizerV13(tekkenizer) + case TokenizerVersion.v15: + instruct_tokenizer = InstructTokenizerV15(tekkenizer) case _: raise ValueError(f"Unsupported version for programmatic Tekken build: {version}") @@ -117,6 +124,22 @@ def v13_tekken() -> MistralTokenizer: return 
_build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True) +_V15_MODEL_SETTINGS_BUILDER = ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=list(ReasoningEffort), + accepts_none=True, + default=None, + ), +) + + +@pytest.fixture(scope="module") +def v15_tekken() -> MistralTokenizer: + return _build_tekken_mistral_tokenizer( + TokenizerVersion.v15, add_think=True, model_settings_builder=_V15_MODEL_SETTINGS_BUILDER + ) + + _PAYMENT_PARAMS: dict[str, Any] = { "type": "object", "additionalProperties": False, @@ -1122,7 +1145,8 @@ def _find_first_json_schema_rejection( Raises: ValueError: If all tokens are accepted. """ - grammar = factory.get_lark_for_json_schema(json_schema=json_schema) + template = factory.select_jinja_template(reasoning=False) + grammar = factory.get_lark_for_json_schema(template=template, json_schema=json_schema) matcher = factory.get_matcher(grammar) for i, token in enumerate(tokens): if not matcher.consume_token(token): @@ -1206,6 +1230,174 @@ def _generate_json_schema(mistral_tokenizer: MistralTokenizer, factory: GrammarF return cases +def _generate_json_schema_reasoning(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + + # Valid JSON should be accepted even under reasoning templates with json_only + items: list[tuple[str, str, int | None, dict[str, Any]]] = [ + ( + "json_schema_reasoning_valid", + '{"name": "John", "age": 30}', + None, + SchemaProvider.basic_person(), + ), + ( + "json_schema_reasoning_whitespace_valid", + '\n {"name": "John", "age": 30}', + None, + {"type": "object"}, + ), + ( + "json_schema_reasoning_invalid_bracket", + '"name": "John", "age": 30}', + 0, + {"type": "object"}, + ), + ] + for case_name, text, should_fail_on, json_schema in items: + tokens = 
_encode_content(instruct_tokenizer, text) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tokens, + should_fail_on=should_fail_on, + case_name=case_name, + mode="auto", + json_schema=json_schema, + reasoning=True, + ) + ) + + return cases + + +def _generate_json_only_negative(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + + cases: list[TestCase] = [] + + # Plain text should be rejected by json_only grammar + text_tokens = _encode_content(instruct_tokenizer, "Hello world!") + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=text_tokens, + should_fail_on=0, + case_name="json_only_rejects_text", + mode="auto", + json_schema={"type": "object"}, + reasoning=False, + ) + ) + + # Tool calls should be rejected by json_only grammar + tool_call_tokens = _encode_content( + instruct_tokenizer, + [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1"}'))], + ) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=tool_call_tokens, + should_fail_on=0, + case_name="json_only_rejects_tool_call", + mode="auto", + json_schema={"type": "object"}, + reasoning=False, + ) + ) + + # Plain text should also be rejected with reasoning=True + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=text_tokens, + should_fail_on=0, + case_name="json_only_reasoning_rejects_text", + mode="auto", + json_schema={"type": "object"}, + reasoning=True, + ) + ) + + return cases + + +def _generate_json_schema_think_with_json( + mistral_tokenizer: MistralTokenizer, factory: GrammarFactory +) -> list[TestCase]: + instruct_tokenizer = mistral_tokenizer.instruct_tokenizer + tokenizer = instruct_tokenizer.tokenizer + assert isinstance(instruct_tokenizer, InstructTokenizerBase) + assert isinstance(instruct_tokenizer, InstructTokenizerV13) + + cases: list[TestCase] 
= [] + + json_schema = SchemaProvider.basic_person() + valid_json = '{"name": "John", "age": 30}' + + def _think_tokens(text: str) -> list[int]: + return instruct_tokenizer.encode_think(ThinkChunk(thinking=text)) + + # Think + valid JSON should be accepted with json_only + reasoning + think_json_tokens = [ + *_think_tokens("Let me think about this..."), + *tokenizer.encode(valid_json, bos=False, eos=False), + tokenizer.eos_id, + ] + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=think_json_tokens, + should_fail_on=None, + case_name="think_with_json_valid", + mode="auto", + json_schema=json_schema, + reasoning=True, + ) + ) + + json_only_tokens = _encode_content(instruct_tokenizer, valid_json) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=json_only_tokens, + should_fail_on=None, + case_name="think_with_json_no_think_valid", + mode="auto", + json_schema=json_schema, + reasoning=True, + ) + ) + + think_text_tokens = [ + *_think_tokens("Let me think..."), + *tokenizer.encode("Hello world!", bos=False, eos=False), + tokenizer.eos_id, + ] + + fail_idx = len(_think_tokens("Let me think...")) + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=think_text_tokens, + should_fail_on=fail_idx, + case_name="think_with_json_rejects_text_after_think", + mode="auto", + json_schema=json_schema, + reasoning=True, + ) + ) + + return cases + + def _generate_cases(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory) -> list[TestCase]: instruct_tokenizer = mistral_tokenizer.instruct_tokenizer assert isinstance(instruct_tokenizer, InstructTokenizerBase) @@ -1214,6 +1406,8 @@ def _generate_cases(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory cases = _generate_general_cases(mistral_tokenizer) cases += _generate_emoji_cases(mistral_tokenizer) cases += _generate_json_schema(mistral_tokenizer, factory) + cases += _generate_json_schema_reasoning(mistral_tokenizer, factory) + cases += _generate_json_only_negative(mistral_tokenizer, factory) 
cases += _generate_cases_tool_calls(mistral_tokenizer) cases += _generate_single_tool_call(mistral_tokenizer) cases += _generate_strict_tool_calls(mistral_tokenizer, factory) @@ -1225,6 +1419,10 @@ def _generate_cases(mistral_tokenizer: MistralTokenizer, factory: GrammarFactory else: cases += _generate_cases_thinking_v11(mistral_tokenizer) + # v15+ supports think_with_json (thinking before JSON in json_only mode) + if tokenizer_version.supports_model_settings: + cases += _generate_json_schema_think_with_json(mistral_tokenizer, factory) + return cases @@ -1241,6 +1439,9 @@ def _get_grammar_factory(mistral_tokenizer: MistralTokenizer) -> GrammarFactory: _ALL_TOKENIZERS: list[MistralTokenizer] = [ _build_tekken_mistral_tokenizer(TokenizerVersion.v11), _build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True), + _build_tekken_mistral_tokenizer( + TokenizerVersion.v15, add_think=True, model_settings_builder=_V15_MODEL_SETTINGS_BUILDER + ), ] _ALL_CASES: list[TestCase] = [] @@ -1261,7 +1462,8 @@ def test_grammar(self, test_case: TestCase) -> None: if test_case.raw_lark is not None: grammar = test_case.raw_lark elif test_case.json_schema is not None and test_case.tools is None: - grammar = factory.get_lark_for_json_schema(json_schema=test_case.json_schema) + template = factory.select_jinja_template(reasoning=test_case.reasoning) + grammar = factory.get_lark_for_json_schema(template=template, json_schema=test_case.json_schema) else: template = factory.select_jinja_template(reasoning=test_case.reasoning) resolved_mode: ToolChoice @@ -1318,6 +1520,12 @@ def test_grammar(self, test_case: TestCase) -> None: (MistralTokenizer.v3(is_tekken=True), False), (_build_tekken_mistral_tokenizer(TokenizerVersion.v11), True), (_build_tekken_mistral_tokenizer(TokenizerVersion.v13, add_think=True), True), + ( + _build_tekken_mistral_tokenizer( + TokenizerVersion.v15, add_think=True, model_settings_builder=_V15_MODEL_SETTINGS_BUILDER + ), + True, + ), ], ) def 
test_grammar_factory_is_supported(self, tokenizer: MistralTokenizer, expected: bool) -> None: From 0a6bbda8bd8c6b2c98fcc86320d3121226ab64c6 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:57:20 +0200 Subject: [PATCH 11/21] Use ids over str for special tokens --- .../guidance/grammar_factory.py | 29 +++++-- tests/guidance/test_guidance.py | 81 ++++++++++++++++--- 2 files changed, 91 insertions(+), 19 deletions(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 444c0b4e..6fdedd57 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -2,7 +2,7 @@ from enum import Enum from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, Callable from mistral_common.guidance.tokenizer import from_mistral_tokenizer from mistral_common.imports import ( @@ -82,6 +82,7 @@ def convert_tool_calls( tools: list[Tool] | None, mode: ToolChoice, parallel_tool_calls: bool, + get_special_token_id: Callable[[str], str], ) -> str: r"""Converts tool calls to a lark grammar string. @@ -96,14 +97,20 @@ def convert_tool_calls( if mode == "none": return "" + tool_calls_token = get_special_token_id("[TOOL_CALLS]") + args_token = get_special_token_id("[ARGS]") + any_strict_true = any(tool.function.strict for tool in tools) if tools else False if not tools or not any_strict_true: if not isinstance(mode, NamedToolChoice): - grammar_tool_call = ' SAFE_WS? /.+/ SAFE_WS? %json {"type": "object"} SAFE_WS?' + grammar_tool_call = ( + f'{tool_calls_token} SAFE_WS? /.+/ {args_token} SAFE_WS? %json {{"type": "object"}} SAFE_WS?' + ) else: grammar_tool_call = ( - f' SAFE_WS? "{mode.function.name}" SAFE_WS? %json {{"type": "object"}} SAFE_WS?' + f'{tool_calls_token} SAFE_WS? "{mode.function.name}" {args_token} ' + 'SAFE_WS? %json {"type": "object"} SAFE_WS?' 
) else: grammar_per_tool = [] @@ -115,11 +122,11 @@ def convert_tool_calls( for tool in tools: args = _get_tool_args_json(tool) grammar_per_tool.append( - f'( SAFE_WS? "{tool.function.name}" SAFE_WS? %json ' + f'({tool_calls_token} SAFE_WS? "{tool.function.name}" {args_token} SAFE_WS? %json ' f"{json.dumps(args, ensure_ascii=False)} SAFE_WS?)" ) grammar_tool_call = f"{' | '.join(grammar_per_tool)}" - + print(grammar_tool_call) return f"({grammar_tool_call})+" if parallel_tool_calls else grammar_tool_call @@ -160,6 +167,16 @@ def __init__(self, tokenizer: MistralTokenizer) -> None: f"got {type(self._tokenizer).__name__} {self._tokenizer.version.value}" ) self._llg_tokenizer = from_mistral_tokenizer(tokenizer) + self._special_token_map = self._build_special_token_map() + + def _build_special_token_map(self) -> dict[str, str]: + """Build a mapping from special token names to their grammar syntax.""" + return {self._tokenizer.id_to_piece(i): f"<[{i}]>" for i in range(self._tokenizer.num_special_tokens)} + + def _special_token_lark(self, token_name: str) -> str: + """Convert special token name to lark grammar syntax.""" + assert token_name in self._special_token_map, f"Unknown special token: {token_name}" + return self._special_token_map[token_name] @property def llg_tokenizer(self) -> "llg.LLTokenizer": @@ -198,7 +215,7 @@ def get_lark_from_jinja( Returns: The rendered lark grammar string. 
""" - fcall = convert_tool_calls(tools, mode, parallel_tool_calls) + fcall = convert_tool_calls(tools, mode, parallel_tool_calls, self._special_token_lark) json_schema_str = json.dumps(json_schema, ensure_ascii=False) if json_schema else None # NamedToolChoice forces a specific tool, which maps to "required" grammar template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else ToolChoiceEnum(mode) diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index 2bbc7de0..bc9aa1b9 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -1545,27 +1545,52 @@ def test_get_matcher_rejects_invalid_grammar(self, v11_tekken: MistralTokenizer) factory.get_matcher("start: INVALID_RULE_REF_THAT_DOES_NOT_EXIST") +def _stub_get_special_token_id(token_name: str) -> str: + r"""Returns a stub lark grammar token for testing.""" + return f"<[{token_name}]>" + + class TestConvertToolCalls: def test_none_mode(self) -> None: - result = convert_tool_calls(tools=None, mode=ToolChoiceEnum.none, parallel_tool_calls=False) + result = convert_tool_calls( + tools=None, + mode=ToolChoiceEnum.none, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) assert result == "" def test_none_mode_with_tools(self) -> None: tools = [ToolProvider.retrieve_payment_date(strict=True)] - result = convert_tool_calls(tools=tools, mode=ToolChoiceEnum.none, parallel_tool_calls=True) + result = convert_tool_calls( + tools=tools, + mode=ToolChoiceEnum.none, + parallel_tool_calls=True, + get_special_token_id=_stub_get_special_token_id, + ) assert result == "" def test_auto_mode_no_tools(self) -> None: - result = convert_tool_calls(tools=None, mode=ToolChoiceEnum.auto, parallel_tool_calls=False) - assert "" in result - assert "" in result + result = convert_tool_calls( + tools=None, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert 
"<[[TOOL_CALLS]]>" in result + assert "<[[ARGS]]>" in result assert "/.+/" in result assert not result.endswith(")+") def test_auto_mode_non_strict(self) -> None: tools = [ToolProvider.retrieve_payment_date(strict=False)] - result = convert_tool_calls(tools=tools, mode=ToolChoiceEnum.auto, parallel_tool_calls=False) - assert "" in result + result = convert_tool_calls( + tools=tools, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert "<[[TOOL_CALLS]]>" in result assert "/.+/" in result assert not result.endswith(")+") @@ -1574,7 +1599,12 @@ def test_auto_mode_strict(self) -> None: ToolProvider.retrieve_payment_date(strict=True), ToolProvider.retrieve_payment_status(strict=True), ] - result = convert_tool_calls(tools=tools, mode=ToolChoiceEnum.auto, parallel_tool_calls=False) + result = convert_tool_calls( + tools=tools, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) assert '"retrieve_payment_date"' in result assert '"retrieve_payment_status"' in result assert not result.endswith(")+") @@ -1582,7 +1612,12 @@ def test_auto_mode_strict(self) -> None: def test_named_tool_choice_non_strict(self) -> None: named = NamedToolChoice(function=FunctionName(name="retrieve_payment_date")) tools = [ToolProvider.retrieve_payment_date(strict=False)] - result = convert_tool_calls(tools=tools, mode=named, parallel_tool_calls=False) + result = convert_tool_calls( + tools=tools, + mode=named, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) assert '"retrieve_payment_date"' in result assert "/.+/" not in result assert not result.endswith(")+") @@ -1593,18 +1628,33 @@ def test_named_tool_choice_strict(self) -> None: ToolProvider.retrieve_payment_date(strict=True), ToolProvider.retrieve_payment_status(strict=True), ] - result = convert_tool_calls(tools=tools, mode=named, parallel_tool_calls=False) + result = 
convert_tool_calls( + tools=tools, + mode=named, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) assert '"retrieve_payment_date"' in result assert '"retrieve_payment_status"' not in result assert not result.endswith(")+") def test_parallel_tool_calls(self) -> None: - result = convert_tool_calls(tools=None, mode=ToolChoiceEnum.auto, parallel_tool_calls=True) + result = convert_tool_calls( + tools=None, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=True, + get_special_token_id=_stub_get_special_token_id, + ) assert result.startswith("(") and result.endswith(")+") def test_empty_params_strict_tool(self) -> None: tool = Tool(function=Function(name="empty_fn", parameters={}, strict=True)) - result = convert_tool_calls(tools=[tool], mode=ToolChoiceEnum.auto, parallel_tool_calls=False) + result = convert_tool_calls( + tools=[tool], + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) assert '"additionalProperties": false' in result assert '"properties": {}' in result assert not result.endswith(")+") @@ -1613,4 +1663,9 @@ def test_named_tool_not_in_strict_tools_raises(self) -> None: named = NamedToolChoice(function=FunctionName(name="non_existent_tool")) tools = [ToolProvider.retrieve_payment_date(strict=True)] with pytest.raises(StopIteration): - convert_tool_calls(tools=tools, mode=named, parallel_tool_calls=False) + convert_tool_calls( + tools=tools, + mode=named, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) From d64ffa402993a2ebb82c10f3d46337b037d26d9b Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 30 Mar 2026 11:53:26 +0200 Subject: [PATCH 12/21] Verify tool is present for NamedToolChoice --- .../guidance/grammar_factory.py | 13 +++-- tests/guidance/test_guidance.py | 50 +++++-------------- 2 files changed, 21 insertions(+), 42 deletions(-) diff --git 
a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 6fdedd57..ece2d35f 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -78,7 +78,7 @@ def _get_tool_args_json(tool: Tool) -> dict[str, Any]: return args -def convert_tool_calls( +def _convert_tool_calls( tools: list[Tool] | None, mode: ToolChoice, parallel_tool_calls: bool, @@ -97,6 +97,11 @@ def convert_tool_calls( if mode == "none": return "" + if isinstance(mode, NamedToolChoice) and all(mode.function.name != tool.function.name for tool in (tools or [])): + raise ValueError( + f"Tool choice requires the {mode.function.name} tool but no tools with this name has been passed." + ) + tool_calls_token = get_special_token_id("[TOOL_CALLS]") args_token = get_special_token_id("[ARGS]") @@ -126,7 +131,6 @@ def convert_tool_calls( f"{json.dumps(args, ensure_ascii=False)} SAFE_WS?)" ) grammar_tool_call = f"{' | '.join(grammar_per_tool)}" - print(grammar_tool_call) return f"({grammar_tool_call})+" if parallel_tool_calls else grammar_tool_call @@ -215,9 +219,10 @@ def get_lark_from_jinja( Returns: The rendered lark grammar string. """ - fcall = convert_tool_calls(tools, mode, parallel_tool_calls, self._special_token_lark) + fcall = _convert_tool_calls(tools, mode, parallel_tool_calls, self._special_token_lark) json_schema_str = json.dumps(json_schema, ensure_ascii=False) if json_schema else None - # NamedToolChoice forces a specific tool, which maps to "required" grammar + # NamedToolChoice forces a specific tool, which maps to "required" grammar. + # Note: _convert_tool_calls verifies that the NamedToolChoice has a valid tool. 
template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else ToolChoiceEnum(mode) think_with_json = self._tokenizer.version.supports_model_settings return _cached_get_lark_from_jinja( diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index bc9aa1b9..2dabfd87 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -6,7 +6,7 @@ import pytest from pydantic import BaseModel, ConfigDict, Field, model_validator -from mistral_common.guidance.grammar_factory import GrammarFactory, convert_tool_calls +from mistral_common.guidance.grammar_factory import GrammarFactory, _convert_tool_calls from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.normalize import get_normalizer @@ -1098,32 +1098,6 @@ def _generate_named_tool_choice(mistral_tokenizer: MistralTokenizer, factory: Gr ) ) - # 6. NamedToolChoice with non-existent tool in tools list - named_nonexistent = NamedToolChoice( - type=ToolTypes.function, - function=FunctionName(name="non_existent_tool"), - ) - nonexistent_call = [ - ToolCall( - function=FunctionCall( - name="non_existent_tool", - arguments='{"arg": "value"}', - ) - ) - ] - nonexistent_tokens = _encode_content(instruct_tokenizer, nonexistent_call) - - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=nonexistent_tokens, - should_fail_on=None, - case_name="named_tool_choice_nonexistent_tool", - mode=named_nonexistent, - tools=[], - ) - ) - return cases @@ -1552,7 +1526,7 @@ def _stub_get_special_token_id(token_name: str) -> str: class TestConvertToolCalls: def test_none_mode(self) -> None: - result = convert_tool_calls( + result = _convert_tool_calls( tools=None, mode=ToolChoiceEnum.none, parallel_tool_calls=False, @@ -1562,7 +1536,7 @@ def test_none_mode(self) -> None: def test_none_mode_with_tools(self) -> None: tools = 
[ToolProvider.retrieve_payment_date(strict=True)] - result = convert_tool_calls( + result = _convert_tool_calls( tools=tools, mode=ToolChoiceEnum.none, parallel_tool_calls=True, @@ -1571,7 +1545,7 @@ def test_none_mode_with_tools(self) -> None: assert result == "" def test_auto_mode_no_tools(self) -> None: - result = convert_tool_calls( + result = _convert_tool_calls( tools=None, mode=ToolChoiceEnum.auto, parallel_tool_calls=False, @@ -1584,7 +1558,7 @@ def test_auto_mode_no_tools(self) -> None: def test_auto_mode_non_strict(self) -> None: tools = [ToolProvider.retrieve_payment_date(strict=False)] - result = convert_tool_calls( + result = _convert_tool_calls( tools=tools, mode=ToolChoiceEnum.auto, parallel_tool_calls=False, @@ -1599,7 +1573,7 @@ def test_auto_mode_strict(self) -> None: ToolProvider.retrieve_payment_date(strict=True), ToolProvider.retrieve_payment_status(strict=True), ] - result = convert_tool_calls( + result = _convert_tool_calls( tools=tools, mode=ToolChoiceEnum.auto, parallel_tool_calls=False, @@ -1612,7 +1586,7 @@ def test_auto_mode_strict(self) -> None: def test_named_tool_choice_non_strict(self) -> None: named = NamedToolChoice(function=FunctionName(name="retrieve_payment_date")) tools = [ToolProvider.retrieve_payment_date(strict=False)] - result = convert_tool_calls( + result = _convert_tool_calls( tools=tools, mode=named, parallel_tool_calls=False, @@ -1628,7 +1602,7 @@ def test_named_tool_choice_strict(self) -> None: ToolProvider.retrieve_payment_date(strict=True), ToolProvider.retrieve_payment_status(strict=True), ] - result = convert_tool_calls( + result = _convert_tool_calls( tools=tools, mode=named, parallel_tool_calls=False, @@ -1639,7 +1613,7 @@ def test_named_tool_choice_strict(self) -> None: assert not result.endswith(")+") def test_parallel_tool_calls(self) -> None: - result = convert_tool_calls( + result = _convert_tool_calls( tools=None, mode=ToolChoiceEnum.auto, parallel_tool_calls=True, @@ -1649,7 +1623,7 @@ def 
test_parallel_tool_calls(self) -> None: def test_empty_params_strict_tool(self) -> None: tool = Tool(function=Function(name="empty_fn", parameters={}, strict=True)) - result = convert_tool_calls( + result = _convert_tool_calls( tools=[tool], mode=ToolChoiceEnum.auto, parallel_tool_calls=False, @@ -1662,8 +1636,8 @@ def test_empty_params_strict_tool(self) -> None: def test_named_tool_not_in_strict_tools_raises(self) -> None: named = NamedToolChoice(function=FunctionName(name="non_existent_tool")) tools = [ToolProvider.retrieve_payment_date(strict=True)] - with pytest.raises(StopIteration): - convert_tool_calls( + with pytest.raises(ValueError): + _convert_tool_calls( tools=tools, mode=named, parallel_tool_calls=False, From eefa0640e89dab03ff8d010b5cbf2016f79a8b61 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 30 Mar 2026 11:56:00 +0200 Subject: [PATCH 13/21] Remove defaults --- src/mistral_common/guidance/grammar_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index ece2d35f..650c4f23 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -42,8 +42,8 @@ def _cached_get_lark_from_jinja( fcall: str, json_schema_str: str | None, parallel_tool_calls: bool, - json_only: bool = False, - think_with_json: bool = False, + json_only: bool, + think_with_json: bool, ) -> str: jinja_template = Template(template) lark_grammar = jinja_template.render( From 7d843bdc8eefb3a63c7a21b9ed78467ba7a017fd Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:12:09 +0200 Subject: [PATCH 14/21] Improve input validation --- .../guidance/grammar_factory.py | 18 ++- tests/guidance/test_guidance.py | 152 ++++++++++-------- 2 files changed, 100 insertions(+), 70 deletions(-) diff --git 
a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 650c4f23..e4d86356 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -25,6 +25,15 @@ JINJA_DIR = Path(__file__).parent / "data" +def _validate_mode_and_tools(mode: ToolChoice, tools: list[Tool] | None): + if isinstance(mode, NamedToolChoice) and all(mode.function.name != tool.function.name for tool in (tools or [])): + raise ValueError( + f"Tool choice requires the {mode.function.name} tool but no tools with this name has been passed." + ) + elif mode in [ToolChoiceEnum.any, ToolChoiceEnum.required] and not tools: + raise ValueError(f"When {mode=} please ensure to pass tools, got {tools=}.") + + @lru_cache() def _cached_get_jinja_template(tokenizer_version: TokenizerVersion, reasoning: bool) -> str: if tokenizer_version < TokenizerVersion.v13: @@ -97,11 +106,6 @@ def _convert_tool_calls( if mode == "none": return "" - if isinstance(mode, NamedToolChoice) and all(mode.function.name != tool.function.name for tool in (tools or [])): - raise ValueError( - f"Tool choice requires the {mode.function.name} tool but no tools with this name has been passed." - ) - tool_calls_token = get_special_token_id("[TOOL_CALLS]") args_token = get_special_token_id("[ARGS]") @@ -219,10 +223,12 @@ def get_lark_from_jinja( Returns: The rendered lark grammar string. """ + # Verifies that the NamedToolChoice has a valid tool and "any", "required" has tools. + _validate_mode_and_tools(mode=mode, tools=tools) + fcall = _convert_tool_calls(tools, mode, parallel_tool_calls, self._special_token_lark) json_schema_str = json.dumps(json_schema, ensure_ascii=False) if json_schema else None # NamedToolChoice forces a specific tool, which maps to "required" grammar. - # Note: _convert_tool_calls verifies that the NamedToolChoice has a valid tool. 
template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else ToolChoiceEnum(mode) think_with_json = self._tokenizer.version.supports_model_settings return _cached_get_lark_from_jinja( diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index 2dabfd87..d80eed6c 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -413,7 +413,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test ( "single_fcall", [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}'))], - {"auto": None, "any": None, "none": 0, "required": None}, + {"auto": None, "none": 0}, ), ( "multi_fcall", @@ -422,7 +422,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test ToolCall(function=FunctionCall(name="hello_1", arguments='{"arg1": "val1", "arg2": "val2"}')), ToolCall(function=FunctionCall(name="hello_2_3", arguments='{"arg1": "val1", "arg2": "val2"}')), ], - {"auto": None, "any": None, "none": 0, "required": None}, + {"auto": None, "none": 0}, ), ( "emoji_fcall", @@ -431,7 +431,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test function=FunctionCall(name="he🧦🧦o", arguments='{"arg1": "🐱", "arg2": "🐢", "arg🧦": "🧦"}'), ) ], - {"auto": None, "any": None, "none": 0, "required": None}, + {"auto": None, "none": 0}, ), ( "pretty_printed_args", @@ -443,12 +443,12 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test ), ) ], - {"auto": None, "any": None, "none": 0, "required": None}, + {"auto": None, "none": 0}, ), ( "japanese_fcall", [ToolCall(function=FunctionCall(name="こんにけは", arguments='{"こん": "にけは"}'))], - {"auto": None, "any": None, "none": 0, "required": None}, + {"auto": None, "none": 0}, ), ] for case_name, content, valid_for in items: @@ -465,9 +465,9 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test ) if tokenizer.version < 
TokenizerVersion.v13: - reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0, "required": 0} + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "none": 0} else: - reasoning_valid_for = {"auto": None, "any": None, "none": 0, "required": None} + reasoning_valid_for = {"auto": None, "none": 0} for mode, should_fail_on in reasoning_valid_for.items(): cases.append( TestCase( @@ -491,7 +491,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test *tokenizer.encode('{"a', bos=False, eos=False), tokenizer.eos_id, ], - {"auto": -1, "any": -1, "none": 0, "required": -1}, + {"auto": -1, "none": 0}, ), ( "fcall_missing_args", @@ -501,7 +501,7 @@ def _generate_cases_tool_calls(mistral_tokenizer: MistralTokenizer) -> list[Test tokenizer.get_special_token("[ARGS]"), tokenizer.get_special_token("[TOOL_CALLS]"), ], - {"auto": -1, "any": -1, "none": 0, "required": -1}, + {"auto": -1, "none": 0}, ), ] @@ -532,7 +532,7 @@ def _generate_cases_text_and_tool_calls(mistral_tokenizer: MistralTokenizer) -> tokens = _encode_content(instruct_tokenizer, content) text_len = len(tokenizer.encode("Hello!", bos=False, eos=False)) - valid_for: dict[Mode, int | None] = {"auto": None, "any": 0, "none": text_len, "required": None} + valid_for: dict[Mode, int | None] = {"auto": None, "none": text_len} for mode, should_fail_on in valid_for.items(): cases.append( @@ -546,9 +546,9 @@ def _generate_cases_text_and_tool_calls(mistral_tokenizer: MistralTokenizer) -> ) if tokenizer.version < TokenizerVersion.v13: - reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0, "required": 0} + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "none": 0} else: - reasoning_valid_for = {"auto": None, "any": None, "none": text_len, "required": None} + reasoning_valid_for = {"auto": None, "none": text_len} for mode, should_fail_on in reasoning_valid_for.items(): cases.append( TestCase( @@ -571,31 +571,29 @@ def 
_generate_cases_thinking_v11(mistral_tokenizer: MistralTokenizer) -> list[Te assert isinstance(instruct_tokenizer, InstructTokenizerBase) cases: list[TestCase] = [] + # Note: "any"/"required" modes require tools, so only "auto"/"none" are used here (tools=None). thinks: list[tuple[str, list[Any], dict[Mode, int | None]]] = [ ( "force_think", [TextChunk(text="Hello world!")], - {"auto": 0, "any": 0, "none": 0, "required": 0}, + {"auto": 0, "none": 0}, ), ( "think_without_response", [TextChunk(text="Hello!")], - {"auto": -1, "any": -1, "none": -1, "required": -1}, + {"auto": -1, "none": -1}, ), ( "unclosed_think", [TextChunk(text="Hello!")], - {"auto": -1, "any": -1, "none": -1, "required": -1}, + {"auto": -1, "none": -1}, ), ( "plain_think_with_response", [TextChunk(text="Hello!World!")], { "auto": None, - # any/required: think fcalls β€” after think, "World!" doesn't match fcalls - "any": len(tokenizer.encode("Hello!", bos=False, eos=False)), "none": None, - "required": len(tokenizer.encode("Hello!", bos=False, eos=False)), }, ), ( @@ -606,9 +604,7 @@ def _generate_cases_thinking_v11(mistral_tokenizer: MistralTokenizer) -> list[Te ], { "auto": None, - "any": None, "none": len(tokenizer.encode("Hello!", bos=False, eos=False)), - "required": None, }, ), ( @@ -620,10 +616,7 @@ def _generate_cases_thinking_v11(mistral_tokenizer: MistralTokenizer) -> list[Te { # auto: think (content | fcalls) β€” picks content for "Ho!", then tool call rejected "auto": len(tokenizer.encode("Hello!Ho!", bos=False, eos=False)), - # any/required: think fcalls β€” after think, "Ho!" 
doesn't match fcalls - "any": len(tokenizer.encode("Hello!", bos=False, eos=False)), "none": len(tokenizer.encode("Hello!Ho!", bos=False, eos=False)), - "required": len(tokenizer.encode("Hello!", bos=False, eos=False)), }, ), ] @@ -656,21 +649,22 @@ def _think_tokens(text: str) -> list[int]: assert isinstance(instruct_tokenizer, InstructTokenizerV13) return instruct_tokenizer.encode_think(ThinkChunk(thinking=text)) + # Note: "any"/"required" modes require tools, so only "auto"/"none" are used here (tools=None). thinks: list[tuple[str, list[Any], dict[Mode, int | None]]] = [ ( "plain_text", [TextChunk(text="Hello world!")], - {"auto": None, "any": -1, "none": None, "required": -1}, + {"auto": None, "none": None}, ), ( "plain_think", [ThinkChunk(thinking="Hello!")], - {"auto": -1, "any": -1, "none": -1, "required": -1}, + {"auto": -1, "none": -1}, ), ( "plain_think_with_response", [ThinkChunk(thinking="Hello!"), TextChunk(text="World!")], - {"auto": None, "any": -1, "none": None, "required": -1}, + {"auto": None, "none": None}, ), ( "think_with_tool_call", @@ -680,9 +674,7 @@ def _think_tokens(text: str) -> list[int]: ], { "auto": None, - "any": None, "none": len(_think_tokens("Hello!")), - "required": None, }, ), ( @@ -694,10 +686,7 @@ def _think_tokens(text: str) -> list[int]: ], { "auto": None, - "any": None, "none": len(_think_tokens("Hello!")) + len(tokenizer.encode("World!", bos=False, eos=False)), - # required: think? content? 
fcalls β€” think+content+fcalls all present, passes - "required": None, }, ), ] @@ -725,17 +714,17 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test cases: list[TestCase] = [] single_call = [ToolCall(function=FunctionCall(name="hello", arguments='{"arg1": "val1", "arg2": "val2"}'))] single_tokens = _encode_content(instruct_tokenizer, single_call) - for mode in _AUTO_ANY_REQUIRED: - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=single_tokens, - should_fail_on=None, - case_name="single_tool_call", - mode=mode, - parallel_tool_calls=False, - ) + # Only "auto" can be used without tools; "any"/"required" require tools. + cases.append( + TestCase( + tokenizer=tokenizer, + tokens=single_tokens, + should_fail_on=None, + case_name="single_tool_call", + mode="auto", + parallel_tool_calls=False, ) + ) cases.append( TestCase( tokenizer=tokenizer, @@ -748,9 +737,9 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test ) if tokenizer.version < TokenizerVersion.v13: - reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "any": 0, "none": 0, "required": 0} + reasoning_valid_for: dict[Mode, int | None] = {"auto": 0, "none": 0} else: - reasoning_valid_for = {"auto": None, "any": None, "none": 0, "required": None} + reasoning_valid_for = {"auto": None, "none": 0} for mode, should_fail_on in reasoning_valid_for.items(): cases.append( TestCase( @@ -772,17 +761,17 @@ def _generate_single_tool_call(mistral_tokenizer: MistralTokenizer) -> list[Test single_tokens_with_eos = _encode_content(instruct_tokenizer, [multi_calls[0]]) fail_idx = len(single_tokens_with_eos) - 1 - for mode in _AUTO_ANY_REQUIRED: - cases.append( - TestCase( - tokenizer=tokenizer, - tokens=multi_tokens, - should_fail_on=fail_idx, - case_name="multi_tool_call_disallowed", - mode=mode, - parallel_tool_calls=False, - ) + # Only "auto" can be used without tools; "any"/"required" require tools. 
+ cases.append( + TestCase( + tokenizer=tokenizer, + tokens=multi_tokens, + should_fail_on=fail_idx, + case_name="multi_tool_call_disallowed", + mode="auto", + parallel_tool_calls=False, ) + ) return cases @@ -1518,6 +1507,52 @@ def test_get_matcher_rejects_invalid_grammar(self, v11_tekken: MistralTokenizer) with pytest.raises(ValueError, match="Invalid grammar"): factory.get_matcher("start: INVALID_RULE_REF_THAT_DOES_NOT_EXIST") + @pytest.mark.parametrize("mode", [ToolChoiceEnum.any, ToolChoiceEnum.required]) + def test_get_lark_rejects_any_required_without_tools( + self, v11_tekken: MistralTokenizer, mode: ToolChoiceEnum + ) -> None: + factory = GrammarFactory(v11_tekken) + template = factory.select_jinja_template(reasoning=False) + with pytest.raises(ValueError, match="please ensure to pass tools"): + factory.get_lark_from_jinja( + template=template, mode=mode, tools=None, json_schema=None, parallel_tool_calls=True + ) + + def test_get_lark_rejects_named_tool_not_in_tools(self, v11_tekken: MistralTokenizer) -> None: + factory = GrammarFactory(v11_tekken) + template = factory.select_jinja_template(reasoning=False) + named = NamedToolChoice(function=FunctionName(name="non_existent")) + tools = [ToolProvider.retrieve_payment_date(strict=True)] + with pytest.raises(ValueError, match="no tools with this name"): + factory.get_lark_from_jinja( + template=template, mode=named, tools=tools, json_schema=None, parallel_tool_calls=True + ) + + @pytest.mark.parametrize("mode", [ToolChoiceEnum.any, ToolChoiceEnum.required]) + @pytest.mark.parametrize("tools", [None, []]) + def test_any_required_without_tools_raises( + self, v11_tekken: MistralTokenizer, mode: ToolChoiceEnum, tools: list[Tool] | None + ) -> None: + factory = GrammarFactory(v11_tekken) + template = factory.select_jinja_template(reasoning=False) + with pytest.raises(ValueError, match="please ensure to pass tools"): + factory.get_lark_from_jinja( + template=template, tools=tools, mode=mode, json_schema=None, 
parallel_tool_calls=True + ) + + @pytest.mark.parametrize("tools", [None, [], [Tool(function=Function(name="existing_tool", parameters={}))]]) + def test_named_wrong_tools_raises(self, v11_tekken: MistralTokenizer, tools: list[Tool] | None) -> None: + factory = GrammarFactory(v11_tekken) + template = factory.select_jinja_template(reasoning=False) + with pytest.raises(ValueError, match="no tools with this name"): + factory.get_lark_from_jinja( + template=template, + tools=tools, + mode=NamedToolChoice(function=FunctionName(name="non_existent_tool")), + json_schema=None, + parallel_tool_calls=True, + ) + def _stub_get_special_token_id(token_name: str) -> str: r"""Returns a stub lark grammar token for testing.""" @@ -1632,14 +1667,3 @@ def test_empty_params_strict_tool(self) -> None: assert '"additionalProperties": false' in result assert '"properties": {}' in result assert not result.endswith(")+") - - def test_named_tool_not_in_strict_tools_raises(self) -> None: - named = NamedToolChoice(function=FunctionName(name="non_existent_tool")) - tools = [ToolProvider.retrieve_payment_date(strict=True)] - with pytest.raises(ValueError): - _convert_tool_calls( - tools=tools, - mode=named, - parallel_tool_calls=False, - get_special_token_id=_stub_get_special_token_id, - ) From 35c49b68373972d9e0f776aa2a626dba6eaeb11b Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:13:19 +0200 Subject: [PATCH 15/21] mypy --- src/mistral_common/guidance/grammar_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index e4d86356..64b809a7 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -25,7 +25,7 @@ JINJA_DIR = Path(__file__).parent / "data" -def _validate_mode_and_tools(mode: ToolChoice, tools: list[Tool] | None): +def 
_validate_mode_and_tools(mode: ToolChoice, tools: list[Tool] | None) -> None: if isinstance(mode, NamedToolChoice) and all(mode.function.name != tool.function.name for tool in (tools or [])): raise ValueError( f"Tool choice requires the {mode.function.name} tool but no tools with this name has been passed." From 95a44bdbd43d61350c2deb9c6c1b52188b12d5b0 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 31 Mar 2026 17:27:09 +0200 Subject: [PATCH 16/21] Solve Patrick comments + pass think token id for v13+ --- .../guidance/data/think_grammar.lark.jinja | 2 +- .../guidance/grammar_factory.py | 50 +++++++++++-------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/mistral_common/guidance/data/think_grammar.lark.jinja b/src/mistral_common/guidance/data/think_grammar.lark.jinja index f49e417c..845285c0 100644 --- a/src/mistral_common/guidance/data/think_grammar.lark.jinja +++ b/src/mistral_common/guidance/data/think_grammar.lark.jinja @@ -26,6 +26,6 @@ fcalls: content? 
{{ fcall }} content: (/(.|\n)+/)+ -think: content +think: {{ begin_think_token }} content {{ end_think_token }} SAFE_WS: /[ \t\r\n]{1,8}/ diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 64b809a7..8067ba4a 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -12,7 +12,7 @@ is_llguidance_installed, ) from mistral_common.protocol.instruct.tool_calls import NamedToolChoice, Tool, ToolChoice, ToolChoiceEnum -from mistral_common.tokens.tokenizers.base import TokenizerVersion +from mistral_common.tokens.tokenizers.base import SpecialTokens, TokenizerVersion from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.tekken import is_tekkenizer @@ -36,12 +36,14 @@ def _validate_mode_and_tools(mode: ToolChoice, tools: list[Tool] | None) -> None @lru_cache() def _cached_get_jinja_template(tokenizer_version: TokenizerVersion, reasoning: bool) -> str: - if tokenizer_version < TokenizerVersion.v13: - jinja_key = _GrammarVariant.plain_think if reasoning else _GrammarVariant.base + if not reasoning: + jinja_key = _GrammarVariant.base + elif tokenizer_version < TokenizerVersion.v13: + jinja_key = _GrammarVariant.plain_think else: - jinja_key = _GrammarVariant.think if reasoning else _GrammarVariant.base - jinja_path = JINJA_PATHS[jinja_key] - return jinja_path.read_text(encoding="utf-8") + jinja_key = _GrammarVariant.think + + return JINJA_PATHS[jinja_key].read_text(encoding="utf-8") @lru_cache() @@ -53,6 +55,8 @@ def _cached_get_lark_from_jinja( parallel_tool_calls: bool, json_only: bool, think_with_json: bool, + begin_think_token: str | None, + end_think_token: str | None, ) -> str: jinja_template = Template(template) lark_grammar = jinja_template.render( @@ -62,6 +66,8 @@ def _cached_get_lark_from_jinja( parallel_tool_calls=parallel_tool_calls, json_only=json_only, think_with_json=think_with_json, + 
begin_think_token=begin_think_token, + end_think_token=end_think_token, ) return lark_grammar @@ -82,9 +88,7 @@ class _GrammarVariant(str, Enum): def _get_tool_args_json(tool: Tool) -> dict[str, Any]: r"""Returns the JSON schema for a tool's arguments.""" args = tool.function.parameters if tool.function.strict else {"type": "object"} - if args == {}: - args = {"type": "object", "properties": {}, "additionalProperties": False} - return args + return args or {"type": "object", "properties": {}, "additionalProperties": False} def _convert_tool_calls( @@ -99,6 +103,7 @@ def _convert_tool_calls( tools: The list of tools available. mode: The tool choice mode. parallel_tool_calls: Whether parallel tool calls are allowed. + get_special_token_id: Callable that maps a special token name to its lark grammar syntax. Returns: The lark grammar string for tool calls. @@ -106,21 +111,16 @@ def _convert_tool_calls( if mode == "none": return "" - tool_calls_token = get_special_token_id("[TOOL_CALLS]") - args_token = get_special_token_id("[ARGS]") + tool_calls_token = get_special_token_id(SpecialTokens.tool_calls.value) + args_token = get_special_token_id(SpecialTokens.args.value) any_strict_true = any(tool.function.strict for tool in tools) if tools else False if not tools or not any_strict_true: - if not isinstance(mode, NamedToolChoice): - grammar_tool_call = ( - f'{tool_calls_token} SAFE_WS? /.+/ {args_token} SAFE_WS? %json {{"type": "object"}} SAFE_WS?' - ) - else: - grammar_tool_call = ( - f'{tool_calls_token} SAFE_WS? "{mode.function.name}" {args_token} ' - 'SAFE_WS? %json {"type": "object"} SAFE_WS?' - ) + tool_name = f'"{mode.function.name}"' if isinstance(mode, NamedToolChoice) else "/.+/" + grammar_tool_call = ( + f'{tool_calls_token} SAFE_WS? {tool_name} {args_token} SAFE_WS? %json {{"type": "object"}} SAFE_WS?' 
+ ) else: grammar_per_tool = [] tools = ( @@ -186,6 +186,10 @@ def _special_token_lark(self, token_name: str) -> str: assert token_name in self._special_token_map, f"Unknown special token: {token_name}" return self._special_token_map[token_name] + def _get_optional_special_token_lark(self, token_name: str) -> str | None: + r"""Returns lark grammar syntax for a special token, or `None` if absent.""" + return self._special_token_map.get(token_name) + @property def llg_tokenizer(self) -> "llg.LLTokenizer": return self._llg_tokenizer @@ -231,6 +235,10 @@ def get_lark_from_jinja( # NamedToolChoice forces a specific tool, which maps to "required" grammar. template_mode = ToolChoiceEnum.required if isinstance(mode, NamedToolChoice) else ToolChoiceEnum(mode) think_with_json = self._tokenizer.version.supports_model_settings + + begin_think_token = self._get_optional_special_token_lark(SpecialTokens.begin_think.value) + end_think_token = self._get_optional_special_token_lark(SpecialTokens.end_think.value) + return _cached_get_lark_from_jinja( template=template, mode=template_mode.value, @@ -239,6 +247,8 @@ def get_lark_from_jinja( parallel_tool_calls=parallel_tool_calls, json_only=json_only, think_with_json=think_with_json, + begin_think_token=begin_think_token, + end_think_token=end_think_token, ) def get_lark_for_json_schema(self, template: str, json_schema: dict[str, Any]) -> str: From c3c3aeed8e74ae6ae1c520a3a30d6ca92ad4ba8e Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 31 Mar 2026 18:11:56 +0200 Subject: [PATCH 17/21] Factorize code --- .../guidance/grammar_factory.py | 36 ++++++++++++------- tests/guidance/test_guidance.py | 20 +++++++++++ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 8067ba4a..087d49d3 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ 
b/src/mistral_common/guidance/grammar_factory.py @@ -85,6 +85,9 @@ class _GrammarVariant(str, Enum): } +_TOOL_CALL_GRAMMAR = "{tool_calls_token} SAFE_WS? {tool_name} {args_token} SAFE_WS? %json {args_json} SAFE_WS?" + + def _get_tool_args_json(tool: Tool) -> dict[str, Any]: r"""Returns the JSON schema for a tool's arguments.""" args = tool.function.parameters if tool.function.strict else {"type": "object"} @@ -118,23 +121,32 @@ def _convert_tool_calls( if not tools or not any_strict_true: tool_name = f'"{mode.function.name}"' if isinstance(mode, NamedToolChoice) else "/.+/" - grammar_tool_call = ( - f'{tool_calls_token} SAFE_WS? {tool_name} {args_token} SAFE_WS? %json {{"type": "object"}} SAFE_WS?' - ) + tool_entries = [(tool_name, '{"type": "object"}')] else: - grammar_per_tool = [] - tools = ( + filtered_tools = ( [next(tool for tool in tools if tool.function.name == mode.function.name)] if isinstance(mode, NamedToolChoice) else tools ) - for tool in tools: - args = _get_tool_args_json(tool) - grammar_per_tool.append( - f'({tool_calls_token} SAFE_WS? "{tool.function.name}" {args_token} SAFE_WS? 
%json ' - f"{json.dumps(args, ensure_ascii=False)} SAFE_WS?)" - ) - grammar_tool_call = f"{' | '.join(grammar_per_tool)}" + tool_entries = [ + (f'"{tool.function.name}"', json.dumps(_get_tool_args_json(tool), ensure_ascii=False)) + for tool in filtered_tools + ] + + grammar_parts = [ + _TOOL_CALL_GRAMMAR.format( + tool_calls_token=tool_calls_token, + args_token=args_token, + tool_name=name, + args_json=args_json, + ) + for name, args_json in tool_entries + ] + + grammar_tool_call = ( + " | ".join(f"({part})" for part in grammar_parts) if len(grammar_parts) > 1 else grammar_parts[0] + ) + return f"({grammar_tool_call})+" if parallel_tool_calls else grammar_tool_call diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index d80eed6c..8bd22185 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -1589,6 +1589,7 @@ def test_auto_mode_no_tools(self) -> None: assert "<[[TOOL_CALLS]]>" in result assert "<[[ARGS]]>" in result assert "/.+/" in result + assert not result.startswith("(") assert not result.endswith(")+") def test_auto_mode_non_strict(self) -> None: @@ -1601,6 +1602,7 @@ def test_auto_mode_non_strict(self) -> None: ) assert "<[[TOOL_CALLS]]>" in result assert "/.+/" in result + assert not result.startswith("(") assert not result.endswith(")+") def test_auto_mode_strict(self) -> None: @@ -1616,6 +1618,21 @@ def test_auto_mode_strict(self) -> None: ) assert '"retrieve_payment_date"' in result assert '"retrieve_payment_status"' in result + assert " | " in result + assert result.count("(") >= 2 + assert not result.endswith(")+") + + def test_auto_mode_single_strict(self) -> None: + tools = [ToolProvider.retrieve_payment_date(strict=True)] + result = _convert_tool_calls( + tools=tools, + mode=ToolChoiceEnum.auto, + parallel_tool_calls=False, + get_special_token_id=_stub_get_special_token_id, + ) + assert '"retrieve_payment_date"' in result + assert " | " not in result + assert not result.startswith("(") 
assert not result.endswith(")+") def test_named_tool_choice_non_strict(self) -> None: @@ -1645,6 +1662,8 @@ def test_named_tool_choice_strict(self) -> None: ) assert '"retrieve_payment_date"' in result assert '"retrieve_payment_status"' not in result + assert " | " not in result + assert not result.startswith("(") assert not result.endswith(")+") def test_parallel_tool_calls(self) -> None: @@ -1666,4 +1685,5 @@ def test_empty_params_strict_tool(self) -> None: ) assert '"additionalProperties": false' in result assert '"properties": {}' in result + assert not result.startswith("(") assert not result.endswith(")+") From e27aa4be8def10333245ca5803dcf8fb5f8c274d Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 31 Mar 2026 18:43:25 +0200 Subject: [PATCH 18/21] Remove useless *args **kwargs --- src/mistral_common/guidance/tokenizer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/mistral_common/guidance/tokenizer.py b/src/mistral_common/guidance/tokenizer.py index a37f789d..005e2ba8 100644 --- a/src/mistral_common/guidance/tokenizer.py +++ b/src/mistral_common/guidance/tokenizer.py @@ -1,5 +1,4 @@ import re -from typing import Any from mistral_common.imports import assert_llguidance_installed, is_llguidance_installed from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenizer @@ -80,13 +79,11 @@ def __init__(self, tokenizer: Tokenizer) -> None: f"{len(self._special_token_ids)}" ) - def __call__(self, s: str, *args: Any, **kwargs: Any) -> list[int]: + def __call__(self, s: str) -> list[int]: r"""Tokenizes a string into token ids. Args: s: The string to tokenize. - *args: Additional positional arguments (ignored). - **kwargs: Additional keyword arguments (ignored). Returns: The list of token ids. 
From c9ab729f7e8e95c6e6cdf85d58f1212e1840b18d Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 31 Mar 2026 18:56:48 +0200 Subject: [PATCH 19/21] Remove LLM matcher --- src/mistral_common/guidance/grammar_factory.py | 17 ----------------- tests/guidance/test_guidance.py | 12 ++++-------- 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 087d49d3..8565452b 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -281,20 +281,3 @@ def get_lark_for_json_schema(self, template: str, json_schema: dict[str, Any]) - parallel_tool_calls=True, json_only=True, ) - - def get_matcher(self, lark: str) -> "llg.LLMatcher": - r"""Creates an LLMatcher from a lark grammar string. - - Args: - lark: The lark grammar string. - - Returns: - The LLMatcher instance. - - Raises: - ValueError: If the grammar is invalid. 
- """ - error = llg.LLMatcher.validate_grammar(lark) - if error: - raise ValueError(f"Invalid grammar: {error}") - return llg.LLMatcher(self._llg_tokenizer, lark) diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index 8bd22185..70d9c856 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any, Literal +import llguidance as llg import pytest from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -284,7 +285,7 @@ def _find_first_rejection( grammar = factory.get_lark_from_jinja( template=template, mode=mode, tools=tools, json_schema=None, parallel_tool_calls=True ) - matcher = factory.get_matcher(grammar) + matcher = llg.LLMatcher(factory.llg_tokenizer, grammar) for i, token in enumerate(tokens): if not matcher.consume_token(token): return i @@ -1110,7 +1111,7 @@ def _find_first_json_schema_rejection( """ template = factory.select_jinja_template(reasoning=False) grammar = factory.get_lark_for_json_schema(template=template, json_schema=json_schema) - matcher = factory.get_matcher(grammar) + matcher = llg.LLMatcher(factory.llg_tokenizer, grammar) for i, token in enumerate(tokens): if not matcher.consume_token(token): return i @@ -1442,7 +1443,7 @@ def test_grammar(self, test_case: TestCase) -> None: parallel_tool_calls=test_case.parallel_tool_calls, ) - matcher = factory.get_matcher(grammar) + matcher = llg.LLMatcher(factory.llg_tokenizer, grammar) assert is_tekkenizer(test_case.tokenizer) debug_tokens = [_token_debug_repr(test_case.tokenizer, t) for t in test_case.tokens] @@ -1502,11 +1503,6 @@ def test_grammar_factory_init_rejects_unsupported(self, tokenizer: MistralTokeni with pytest.raises(ValueError, match="Guidance requires a Tekken tokenizer with version >= v11"): GrammarFactory(tokenizer) - def test_get_matcher_rejects_invalid_grammar(self, v11_tekken: MistralTokenizer) -> None: - factory = GrammarFactory(v11_tekken) - with 
pytest.raises(ValueError, match="Invalid grammar"): - factory.get_matcher("start: INVALID_RULE_REF_THAT_DOES_NOT_EXIST") - @pytest.mark.parametrize("mode", [ToolChoiceEnum.any, ToolChoiceEnum.required]) def test_get_lark_rejects_any_required_without_tools( self, v11_tekken: MistralTokenizer, mode: ToolChoiceEnum From d5a684e4ddee9d0a5724ed382119c6660837d366 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:05:51 +0200 Subject: [PATCH 20/21] Use Enum instead of string --- src/mistral_common/guidance/grammar_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 8565452b..5286633a 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -111,7 +111,7 @@ def _convert_tool_calls( Returns: The lark grammar string for tool calls. """ - if mode == "none": + if mode == ToolChoiceEnum.none: return "" tool_calls_token = get_special_token_id(SpecialTokens.tool_calls.value) From 5b8bed3d3ddc2cf925941f6260238e645c1e10f3 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:11:04 +0200 Subject: [PATCH 21/21] Apply suggestions from code review Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> --- src/mistral_common/guidance/grammar_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mistral_common/guidance/grammar_factory.py b/src/mistral_common/guidance/grammar_factory.py index 5286633a..c4f251aa 100644 --- a/src/mistral_common/guidance/grammar_factory.py +++ b/src/mistral_common/guidance/grammar_factory.py @@ -190,11 +190,11 @@ def __init__(self, tokenizer: MistralTokenizer) -> None: self._special_token_map = self._build_special_token_map() def _build_special_token_map(self) -> dict[str, str]: - """Build 
a mapping from special token names to their grammar syntax.""" + r"""Build a mapping from special token names to their grammar syntax.""" return {self._tokenizer.id_to_piece(i): f"<[{i}]>" for i in range(self._tokenizer.num_special_tokens)} def _special_token_lark(self, token_name: str) -> str: - """Convert special token name to lark grammar syntax.""" + r"""Convert special token name to lark grammar syntax.""" assert token_name in self._special_token_map, f"Unknown special token: {token_name}" return self._special_token_map[token_name]