diff --git a/Cargo.toml b/Cargo.toml index adfda6b9..36287715 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,8 @@ categories = ["algorithms", "hardware-support", "science"] license = "MIT" [dependencies] -candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9" } -candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9" } +candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" } +candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" } serde = { version = "1.0.190", features = ["serde_derive"] } tokenizers = {version = "0.21.2", features = ["http"] } hf-hub = "0.4.1" @@ -21,7 +21,8 @@ itertools = "0.13.0" akin = "0.4.0" indicatif = "0.17.11" serde_json = "1.0.108" -llguidance = "0.6" +llguidance = { version = "1.6", default-features = false, features = ["lark"] } +toktrie_hf_tokenizers = "1.6" toktrie = "1.4" half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] } tokio = { version = "1.38.0", features = ["sync"] } @@ -35,6 +36,7 @@ interprocess = "2.2.2" serde-big-array = "0.5.1" bincode = { version = "1.3.1" } twox-hash = "2.1.1" +rmp-serde = "1.3.1" rand = "0.9.0" rayon="1.10.0" clap = { version = "4.4.7", features = ["derive"] } @@ -44,7 +46,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.1", rev = "af0b475" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.1", rev = "29e4beb" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" @@ -60,7 +62,7 @@ utoipa = { version = "4.2", features = ["axum_extras"] } colored = { version = "3.0.0" } tower-http = { version = "0.6.6", features = ["cors"] } rustchatui = { git = "https://github.com/guoqingbao/rustchatui.git", rev = "68caad9" } -sysinfo = "0.37.2" +sysinfo = "0.38.3" image = { version = "0.25.6", default-features = false, features = ['bmp', 'gif', 'jpeg', 'png', 'tiff', 'webp'] } reqwest = { version = "0.12.24", features = ["blocking", "json", "rustls-tls"]} bytemuck = "1.24.0" @@ -89,3 +91,7 @@ python = ["pyo3"] [[bin]] name = "runner" path = "src/runner/runner.rs" + +[[bin]] +name = "special-tokens-extraction" +path = "example/special-tokens-extraction/src/main.rs" diff --git a/ReadMe-CN.md b/ReadMe-CN.md index cdb1a046..57acdde4 100644 --- a/ReadMe-CN.md +++ b/ReadMe-CN.md @@ -319,8 +319,30 @@ cargo install --features metal --- -## 🔌 MCP集成 (工具调用) +## 🔌 LLGuidance 支持(结构化输出与约束) + +vLLM.rs 现在支持通过 llguidance 库实现结构化输出和约束生成: + +- **工具调用优化**:使用 `--enable-tool-grammar` 启用工具调用语法,强制模型输出符合工具参数schema的JSON结构 +- **自定义约束**:使用 `--allow-constraint-api` 允许客户端通过 structured_outputs 或 response_format 提交 Lark/Regex/JSON Schema 约束 +- **正则表达式约束**:强制输出符合特定格式(如电话号码、日期等) +- **JSON Schema 约束**:通过 OpenAI 兼容的 response_format 或 structured_outputs 提交自定义约束 + +**使用示例:** +```bash +# 启用工具调用语法(自动从 MCP 工具构建 LLG 语法) +vllm-rs --m Qwen/Qwen3-30B-A3B-Instruct --enable-tool-grammar --ui-server + +# 启用客户端约束API(允许OpenAI风格的structured_outputs/response_format) +vllm-rs --m Qwen/Qwen3-30B-A3B-Instruct --allow-constraint-api --ui-server +``` + +查看 [**结构化输出文档 →**](docs/llguidance-integration.md) + +--- + +## 🔌 MCP集成 (工具调用) 通过Model Context Protocol让LLM调用外部工具。 ```bash @@ -455,6 +477,8 @@ pip install target/wheels/vllm_rs-*-cp38-abi3-*.whl --force-reinstall | `--kv-fraction` | 用于控制KVCache使用量 (模型加载后剩余可用GPU显存的百分比) | | `--prefix-cache` | 启用前缀缓存,用于多轮对话 | | `--prefix-cache-max-tokens` | 限制前缀缓存大小(按 block size 向下取整) | +| `--allow-constraint-api` | 允许通过HTTP API提交客户端约束(默认:false) | +| `--enable-tool-grammar` | 自动从工具构建LLG语法(默认:false) | ### MCP配置参数 diff --git a/ReadMe.md b/ReadMe.md index 68a3180e..f74f12be 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -95,9 +95,9 @@ All models support hardware FP8 KV-cache acceleration (requires SM90+ and disabl ## 📘 Usage in Python ### 📦 Install with pip -- 💡 **CUDA compute capability < 8.0** (e.g., V100) requires a **manual build** +- 💡 **CUDA compute capability < 8.0** (e.g., V100) requires a **manual build** (no `flash-attn` support; alternatively use **Rust mode**). -- 💡 The **prebuilt wheel** is built with the `flash-context` feature enabled. +- 💡 The **prebuilt wheel** is built with the `flash-context` feature enabled. To use **FP8 KV Cache**, you must **build manually** (remove the `flash-context` build flag). @@ -284,7 +284,7 @@ Use `--i` to enable interactive mode 🤖, `--ui-server` or `--server` to enable # Metal/MacOS vllm-rs --m Qwen/Qwen3-4B-GGUF --f Qwen3-4B-Q4_K_M.gguf --ui-server --prefix-cache ``` - +
Multi-GPU + Unquantized Model @@ -332,6 +332,28 @@ vllm-rs --m Qwen/Qwen3-4B-Instruct-2507-FP8 --ui-server --prefix-cache --- +## 🔌 LLGuidance Support (Structured Outputs & Constraints) + +vLLM.rs now supports structured output and constraint-based generation via llguidance: + +- **Tool Call Optimization**: Use `--enable-tool-grammar` to auto-build LLG grammar from tools, forcing model to output JSON matching tool parameter schemas +- **Custom Constraints**: Use `--allow-constraint-api` to allow clients to submit Lark/Regex/JSON Schema constraints via OpenAI-compatible structured_outputs/response_format +- **Regex Constraints**: Enforce output formats like phone numbers (`^number\s\d{3}-\d{3}-\d{4}$`) +- **JSON Schema Constraints**: Enforce structured output via response_format or structured_outputs + +**Usage Examples:** +```bash +# Enable tool grammar (auto-builds LLG grammar from MCP tools) +vllm-rs --m Qwen/Qwen3-30B-A3B-Instruct --enable-tool-grammar --ui-server + +# Enable client constraints API (accepts structured_outputs/response_format) +vllm-rs --m Qwen/Qwen3-30B-A3B-Instruct --allow-constraint-api --ui-server +``` + +See [**Structured Outputs Documentation →**](docs/llguidance-integration.md) + +--- + ## 🔌 MCP Integration (Tool Calling) Enable LLMs to call external tools via Model Context Protocol. @@ -434,7 +456,7 @@ PD Disaggregation separates prefill (prompt processing) and decode (token genera ## 📽️ Demo Video -Watch it in action 🎉 +Watch it in action 🎉 @@ -462,19 +484,19 @@ pip install maturin[patchelf] # For Linux/Windows 2. **Build the Python package** ```bash -# Naive CUDA (single GPU only) +# Naive CUDA (single GPU only) maturin build --release --features cuda,python # Naive CUDA (+CUDA Graph, experimental) ./build.sh --release --features cuda,graph,python -# CUDA (with prefix-cache and FP8 KV Cache, no Flash Attention, compatible with V100) +# CUDA (with prefix-cache and FP8 KV Cache, no Flash Attention, compatible with V100) ./build.sh --release --features cuda,nccl,python -# CUDA (+Flash Attention, only used in prefill stage) +# CUDA (+Flash Attention, only used in prefill stage) ./build.sh --release --features cuda,nccl,flash-attn,python -# CUDA (+cutlass (sm90+), +Flash Attention for decoding, +high prefill throughput, long time to build) +# CUDA (+cutlass (sm90+), +Flash Attention for decoding, +high prefill throughput, long time to build) ./build.sh --release --features cuda,nccl,flash-attn,flash-context,cutlass,python # macOS (Metal, single GPU only, with prefix-cache and FP8 kvcache) @@ -518,6 +540,8 @@ pip install target/wheels/vllm_rs-*-cp38-abi3-*.whl --force-reinstall | `--kv-fraction` | control kvcache usage (percentage of remaining gpu memory after model loading) | | `--prefix-cache` | Enable prefix caching for multi-turn conversations | | `--prefix-cache-max-tokens` | Cap prefix cache size in tokens (rounded down to block size) | +| `--allow-constraint-api` | Allow client-submitted constraints via HTTP API (default: false) | +| `--enable-tool-grammar` | Automatically build LLG grammar from tools (default: false) | ### MCP Configuration @@ -563,7 +587,7 @@ pip install target/wheels/vllm_rs-*-cp38-abi3-*.whl --force-reinstall * [x] **Claude/Anthropic-compatible API Server** * [x] **Support CUDA 13** * [x] **Support FlashInfer backend** -* [ ] TentorRT-LLM +* [ ] TentorRT-LLM --- ## 📚 References diff --git a/docs/goose.md b/docs/goose.md index 9d3c76dd..70d6b178 100644 --- a/docs/goose.md +++ b/docs/goose.md @@ -17,35 +17,33 @@ python3 -m vllm_rs.server --m Qwen/Qwen3-30B-A3B-Instruct-2507 --d 0,1 --server ## 2) Configure Goose -### Download and install Goose: https://block.github.io/goose/docs/getting-started/installation/ - ```shell # For non-UI system, export GOOSE_DISABLE_KEYRING=1 ``` - Export empty API KEY ```shell export VLLM_API_KEY="empty" ``` +### Download and install Goose: https://block.github.io/goose/docs/getting-started/installation/ ### Configure goose with `Custom Providers` and API key `empty` ```shell goose configure -┌ goose-configure +┌ goose-configure │ ◇ What would you like to configure? -│ Custom Providers +│ Custom Providers │ ◇ What would you like to do? -│ Add A Custom Provider +│ Add A Custom Provider │ ◇ What type of API is this? -│ OpenAI Compatible +│ OpenAI Compatible │ ◇ What should we call this provider? │ vllm-rs @@ -60,10 +58,10 @@ goose configure │ default │ ◇ Does this provider support streaming responses? -│ Yes +│ Yes │ ◇ Does this provider require custom headers? -│ No +│ No │ └ Custom provider added: vllm-rs └ Configuration saved successfully to /root/.config/goose/config.yaml diff --git a/docs/llguidance-integration.md b/docs/llguidance-integration.md new file mode 100644 index 00000000..f3caea5b --- /dev/null +++ b/docs/llguidance-integration.md @@ -0,0 +1,2175 @@ +# LLGuidance Integration Documentation + +## Overview + +This document provides comprehensive documentation for the llguidance integration in vllm.rs, covering: + +1. **Architecture** - System design and component interactions +2. **Data Flow** - Complete request-to-response flow +3. **API Reference** - All public functions and their signatures +4. **Usage Examples** - Common patterns and use cases +5. **Mathematical Foundations** - How the grammar system works +6. **Rollback Mechanics** - State recovery and consistency +7. **Grammar Construction** - How grammars are composed and merged +8. **Grammar Rule Order** - Correct ordering of Lark grammar rules + +--- + +## 1. ARCHITECTURE + +### Component Overview + +```mermaid +sequenceDiagram + participant User + participant API + participant Pipeline + participant SpecialTokens + participant LLGFactory + participant Matcher + participant TokenParser + participant EarleyParser + participant Lexer + participant TokTrie + participant Sampler + participant LogitsProcessor + participant Model + + User->>API: Request with constraint (regex/json_schema/lark/llguidance) + + Note over User,API: Phase 1: Request Setup and Grammar Building + + API->>SpecialTokens: SpecialTokens::new(&tokenizer) + SpecialTokens-->>API: Return EOS, BOS, TOOL token IDs + API->>Pipeline: build_llg_factory(tokenizer) + Pipeline->>LLGFactory: toktrie_hf_tokenizers::ByteTokenizer::from_tokenizer(tokenizer) + LLGFactory->>TokTrie: Create token trie from tokenizer vocabulary + TokTrie-->>LLGFactory: Return TokEnv with trie + LLGFactory->>LLGFactory: ParserFactory::new_simple(&env) + LLGFactory-->>Pipeline: Return Arc + + Pipeline->>Pipeline: llg_grammar_from_constraint(&request.constraint) + Pipeline->>Matcher: constraint_from_llg_grammar(&factory, grm) + Matcher->>Matcher: factory.create_parser(grm) + Matcher->>TokenParser: Create with grammar_init + TokenParser->>EarleyParser: Build CGrammar from grammar + TokenParser->>Lexer: Build LexerSpec from grammar + Lexer->>TokTrie: Precompute large lexemes if needed + TokTrie-->>Lexer: Return optimized lexeme sets + + Note over User,Matcher: Phase 2: Prompt Processing (if needed) + + User->>API: Optional: process_prompt(prompt_tokens) + API->>TokenParser: process_prompt(prompt_tokens) + TokenParser->>TokenParser: tokenize_bytes_marker(&prompt_bytes) + TokenParser->>TokenParser: process_prompt() returns new prompt + + Note over User,Matcher: Phase 3: Inference Loop + + loop for each token generation + + Model->>Model: Forward pass on input tokens + Model-->>Pipeline: Return logits tensor + + Pipeline->>Sampler: sample_sequence(logits, seq, ...) + + Note over Sampler: Two-stage sampling with llguidance + + Sampler->>LogitsProcessor: Apply llguidance constraint + + LogitsProcessor->>TokenParser: compute_mask() + TokenParser->>TokenParser: compute_mask_inner() + TokenParser->>EarleyParser: run_speculative("compute_mask") + EarleyParser->>EarleyParser: trie_started("compute_mask") + EarleyParser->>EarleyParser: compute_bias() + EarleyParser->>Lexer: compute_bias() with token_prefix + + Note over Lexer,TokTrie: Lexical Scope Analysis + + Lexer->>TokTrie: Walk token trie for allowed lexemes + TokTrie-->>Lexer: Return SimpleVob bit mask + + Lexer->>EarleyParser: Return mask to TokenParser + TokenParser->>TokenParser: cache mask for fast-forward + + TokenParser-->>LogitsProcessor: Return SimpleVob mask + + LogitsProcessor->>LogitsProcessor: Check if sampled token is allowed + LogitsProcessor->>Sampler: Apply logit biasing + + alt Token is allowed + Sampler->>Sampler: No biasing needed + else Token is not allowed + Sampler->>Sampler: Set invalid tokens to -f32::INFINITY + Sampler->>Sampler: Re-sample with biased logits + end + + Sampler->>TokenParser: consume_token(sampled_token) + TokenParser->>TokenParser: apply_token(sampled_token) + TokenParser->>TokenParser: llm_tokens.push(sampled_token) + TokenParser->>TokenParser: llm_bytes.extend(token_bytes) + TokenParser->>EarleyParser: parser.apply_token(token_bytes, token_id) + EarleyParser->>Lexer: advance lexer state + Lexer->>Lexer: Update lexer_stack with new state + Lexer->>EarleyParser: Return backtrack count + + alt Backtrack needed + EarleyParser->>EarleyParser: rollback(backtrack_bytes) + EarleyParser->>EarleyParser: Update llm_tokens and llm_bytes + end + + TokenParser->>TokenParser: check_stop() + TokenParser-->>Sampler: Return CommitResult + + Note over Sampler: Phase 4: Fast-Forward (if enabled) + + Sampler->>TokenParser: compute_ff_tokens() + TokenParser->>TokenParser: ff_tokens() + TokenParser->>TokTrie: Tokenize forced bytes + TokTrie-->>TokenParser: Return fast-forward tokens + + alt Fast-forward tokens available + TokenParser->>TokenParser: consume_ff_tokens() + loop for each ff_token + TokenParser->>TokenParser: consume_token(ff_token) + TokenParser->>TokenParser: llm_tokens.push(ff_token) + TokenParser->>TokenParser: llm_bytes.extend(ff_token_bytes) + end + end + + Note over Sampler: Phase 5: Speculative Decoding (if enabled) + + Model->>Model: Draft model forward pass + Model-->>Pipeline: Return draft logits + + Pipeline->>Sampler: sample_target_sequence_speculative() + Sampler->>TokenParser: rollback(n_toks) + TokenParser->>EarleyParser: parser.rollback(bytes_to_drop) + EarleyParser->>Lexer: pop lexer states + Lexer-->>TokenParser: Return rollback result + + Sampler->>Sampler: Sample draft tokens + Sampler->>TokenParser: validate_tokens(draft_tokens) + TokenParser->>TokenParser: consume_token(draft_token) + + alt Draft token accepted + TokenParser->>TokenParser: Continue with next draft + else Draft token rejected + TokenParser->>TokenParser: Accept partial draft + TokenParser->>TokenParser: Rollback to last valid state + end + + end + + Note over User,Matcher: Phase 6: Token Geometry and Binary Data State + + TokTrie->>TokTrie: Token encoding (8:24 bit split) + TokTrie->>TokTrie: node.bits = (token_id << 8) | byte + TokTrie->>TokTrie: node.bits2 = (subtree_size << 10) | num_parents + + TokTrie->>SimpleVob: Bit mask storage + SimpleVob->>SimpleVob: data: Vec (32 tokens per word) + SimpleVob->>SimpleVob: allow_token(tok): data[tok>>5] |= 1 << (tok&31) + + Note over User,Matcher: Phase 7: Rollback and Verification + + TokenParser->>TokenParser: validate_tokens(tokens) + TokenParser->>EarleyParser: validate_tokens_raw(tokens) + EarleyParser->>Lexer: Check if tokens match current lexer state + Lexer-->>TokenParser: Return number of valid tokens + + TokenParser->>TokenParser: rollback(n_tokens) + TokenParser->>EarleyParser: parser.rollback(bytes_to_drop) + EarleyParser->>Lexer: pop lexer states + TokenParser->>TokenParser: llm_tokens.truncate(new_len) + TokenParser->>TokenParser: llm_bytes.truncate(new_len) + + Note over User,Matcher: Phase 8: Response Generation + + Pipeline->>API: Return completion with tokens + API->>User: Stream or return final response + end +``` + +### Key Data Structures + +#### `TopLevelGrammar` Struct +**Location**: [`src/utils/guidance.rs:515-523`](src/utils/guidance.rs:515-523) + +The `TopLevelGrammar` is provided by llguidance and represents a complete grammar specification. + +```rust +pub struct GuidanceState { + matcher: Matcher, // llguidance Matcher instance + llm_tokens: Vec, // Track committed tokens for rollback + llm_bytes: usize, // Track byte position for rollback + slicer_cache: SlicerCache, // Cache for precomputed mask slices +} +``` + +#### `SamplingParams` Struct +**Location**: [`src/utils/config.rs`](src/utils/config.rs) + +```rust +pub struct SamplingParams { + pub temperature: Option, + pub max_tokens: Option, + pub ignore_eos: bool, + pub top_k: Option, + pub top_p: Option, + pub session_id: Option, + pub frequency_penalty: Option, + pub presence_penalty: Option, + pub stop_sequences: Option>, + // stop_token_ids removed - now uses SpecialTokens for EOS detection + pub mcp_mode: Option, // Tool call mode + pub grammar: Option, // LLG constraint (TopLevelGrammar) + pub thinking: Option, + #[cfg(feature = "python")] + pub grammar_json: Option, // Grammar as JSON string for Python API +} +``` + +**Note**: The `stop_token_ids` field has been removed. Stop sequences are now resolved using the engine's `SpecialTokens` instance for consistent EOS detection across the system. + +#### `EngineConfig` Struct +**Location**: [`src/utils/config.rs:244-289`](src/utils/config.rs:244-289) + +```rust +pub struct EngineConfig { + pub model_id: Option, + pub weight_path: Option, + pub weight_file: Option, + pub enforce_parser: Option, + // ... other fields ... + pub allow_constraint_api: bool, // Allow client constraints via API + pub enable_tool_grammar: bool, // Auto-generate tool grammars from resolved tools + // ... other fields ... +} +``` + +#### Constraint Building Functions + +The server layer provides several functions to convert client requests to `TopLevelGrammar`: + +| Function | Location | Description | +|----------|----------|-------------| +| `grammar_fragment_from_structured_outputs()` | [`src/server/mod.rs:167`](src/server/mod.rs:167) | Convert StructuredOutputs to TopLevelGrammar | +| `grammar_fragment_from_response_format()` | [`src/server/mod.rs:248`](src/server/mod.rs:248) | Convert ResponseFormat to TopLevelGrammar | + +### Grammar Composition + +The `compose_grammars()` function ([`src/utils/guidance.rs:581`](src/utils/guidance.rs:581)) handles grammar composition: + +```rust +pub fn compose_grammars( + constraint_grammars: Vec, // Client-provided constraints + tool_grammar: Option, // Tool call grammar (if enabled) + has_tools: bool, // Whether tools are present + tool_choice_required: bool, // Whether tool_choice is "required" + forced_tool_name: Option, // Specific tool forced via tool_choice + max_tokens: Option, + special_tokens: &SpecialTokens, // ← EOS token IDs for TEXT patterns +) -> TopLevelGrammar +``` + +This function handles 8 different scenarios: +1. No constraint, no tools → text with EOS bounding +2. No constraint, tools optional → tool_call | text with EOS +3. No constraint, tools required → tool_call only +4. No constraint, tools optional, specific tool forced → tool_call only +5. Constraint only, no tools → constraint only +6. Constraint only, tools optional → constraint | tool_call +7. Constraint only, tools required → constraint | tool_call +8. Constraint only, specific tool forced → constraint | tool_call + +**Note**: The `special_tokens: &SpecialTokens` parameter provides EOS token IDs for proper freeform text termination via [`chat_text_expression_with_eos()`](src/utils/guidance.rs:485). + +--- + +## 2. DATA FLOW + +### Full Request Flow + +``` +User Request + │ + ▼ +[server/server.rs:250] chat_completion() + │ + ├─ Parse request fields (messages, tools, tool_choice, etc.) + ├─ Resolve tools (MCP + request tools) + │ + ├─ Build constraint grammars from structured_outputs/response_format + │ ├─ grammar_fragment_from_structured_outputs() [server/mod.rs:167] + │ │ └─ Handles choice, regex, json, grammar, structural_tag + │ └─ grammar_fragment_from_response_format() [server/mod.rs:248] + │ └─ Handles response_format with json_schema type + │ + ├─ Build tool grammar if enable_tool_grammar=true + │ [tools/schema.rs:87] build_json_tool_lark_grammar() + │ └─ Creates Lark grammar with tool schemas as %json directives + │ + ├─ Initialize SpecialTokens for EOS detection + │ [utils/special_tokens.rs:133] SpecialTokens::new(&tokenizer) + │ └─ Extracts EOS, BOS, PAD, TOOL, FUNCTION token IDs from tokenizer + │ + ├─ Compose grammars via compose_grammars() [utils/guidance.rs:581] + │ └─ Passes special_tokens: &SpecialTokens for EOS handling + │ + ▼ +[core/engine.rs:1298] generate_stream() + │ + ├─ Apply chat template to messages + ├─ Create Sequence with grammar in SamplingParams + ├─ Allocate KV cache blocks + └─ Initialize GuidanceState if grammar exists + [utils/guidance.rs:526] GuidanceState::new_from_grammar() + ├─ [llguidance] ParserFactory::create_parser(grammar)? + ├─ [llguidance] Matcher::new(Ok(parser)) + │ └─ TokenParser created with Grammar + │ └─ EarleyParser with LexerSpec + │ └─ RegexVec for token matching + └─ GuidanceState initialized with matcher + │ + ▼ +[core/scheduler.rs:616] postprocess() + │ + ├─ For each generated token: + ├─ [core/runner.rs:1597] validate_sequence_for_grammar() + │ └─ [utils/guidance.rs:672] GuidanceState::validate_tokens(output_ids) + │ └─ matcher.validate_tokens() → Option + │ + └─ If validation fails (< output_ids.len()): + └─ [scheduler.rs:227] rollback_sequence() + ├─ Save rollback snapshot + ├─ Truncate token_ids, block_table + ├─ [core/block_manager.rs:946] rollback_to_seq_tokens() + │ └─ Release blocks, clean prefix cache + └─ [core/runner.rs:1607] rollback_sequence_for_guidance() + └─ [utils/guidance.rs:645] GuidanceState::rollback_to() + └─ matcher.rollback(tokens_to_rollback)? + │ + ▼ +[core/runner.rs:1100] sample() + │ + ├─ [utils/guidance.rs:543] GuidanceState::compute_mask() + │ └─ matcher.compute_mask() → SimpleVob + │ + ├─ Apply mask to logits (set invalid to -inf) + │ + └─ [utils/guidance.rs:617] GuidanceState::consume_ff_tokens() + ├─ matcher.compute_ff_tokens() → Vec + └─ For each token: consume + update state + │ + ▼ +Response to Client +``` + +### Streaming Tool Call Flow + +``` +[server/parser.rs:488] StreamToolParser::process_token() + │ + ├─ ParserState::Normal + │ ├─ Check for start tag (, [TOOL_CALLS], etc.) + │ └─ [server/parser.rs:527] is_start_token() + │ └─ Token ID match OR text match + │ + ├─ ParserState::Buffering + │ ├─ Accumulate tokens in buffer + │ ├─ [tool_parser crate] parse_incremental() + │ ├─ Check for end tag + │ └─ [server/parser.rs:587] is_end_token() + │ + └─ ParserState::ToolCalls + ├─ [server/parser.rs:766] build_tool_calls_with_fallback() + ├─ [server/parser.rs:946] parse_complete_with_fallback() + │ ├─ QwenCoder XML parsing + │ ├─ JSON array parsing + │ └─ JSON object parsing + └─ [tools/helpers.rs:102] filter_tool_calls() +``` + +--- + +## 3. GRAMMAR CONSTRUCTION + +### Overview + +The grammar construction system in vllm.rs uses llguidance's `TopLevelGrammar` to represent constraints. When multiple grammars need to be combined (e.g., `constraint | tool_call`), the `compose_grammars()` function handles the composition logic. + +### Grammar Composition Logic + +The `compose_grammars()` function ([`src/utils/guidance.rs:363`](src/utils/guidance.rs:363)) handles 8 different scenarios based on: +- Whether constraint grammars are present +- Whether tool grammars are available +- Whether tool_choice is "required" +- Whether a specific tool is forced + +| Constraint | Tools | tool_choice | Result | +|------------|-------|-------------|---------| +| None | None | - | TEXT only | +| None | Yes | Optional | TEXT \| tool_call | +| None | Yes | Required | tool_call only | +| Yes | None | - | constraint only | +| Yes | Yes | Optional | constraint \| tool_call | +| Yes | Yes | Required | constraint \| tool_call | + +### Lark Grammar Rule Order + +**Critical Requirement**: In Lark grammars, rules must be ordered such that: +1. The `start:` rule is defined FIRST (it's the entry point) +2. Rules that are referenced by other rules must be defined BEFORE those rules +3. Helper rules like `ws:` (whitespace) are defined LAST + +**Incorrect order** (causes parsing errors): +```lark +start: TEXT | tool_call +tool_call: "<‌tool_call>" ws json_array ws "<‌/tool_call>" +json_array: "[" obj ("," obj)* "]" +obj: obj_search | obj_weather +ws: /[ \t\r\n]+/ ← Helper rule defined BEFORE referenced rules +``` + +**Correct order** (dependencies first, helpers last): +```lark +start: TEXT | tool_call +tool_call: "<‌tool_call>" ws json_array ws "<‌/tool_call>" +json_array: "[" obj ("," obj)* "]" +obj: obj_search | obj_weather +obj_search: %json {...} +obj_weather: %json {...} +ws: /[ \t\r\n]+/ ← Helper rule defined LAST +``` + +### Grammar Composition + +```mermaid +sequenceDiagram + participant User + participant Server + participant ConstraintBuilder + participant ToolGrammarBuilder + participant ComposeLogic + participant LarkParser + participant EarleyCompiler + + User->>Server: Request with tools + structured_outputs + Server->>ConstraintBuilder: grammar_fragment_from_structured_outputs() + ConstraintBuilder->>LarkParser: Parse JSON schema + + LarkBuilder->>ConstraintBuilder: Return TopLevelGrammar + Server->>ToolGrammarBuilder: build_json_tool_lark_grammar() if enabled + + ToolGrammarBuilder->>LarkParser: Build tool call grammar + LarkParser->>ToolGrammarBuilder: Return tool TopLevelGrammar + + Note over ComposeLogic: compose_grammars() + + ComposeLogic->>ComposeLogic: Determine match arm based on: + ComposeLogic->>ComposeLogic: - constraint_grammars length + ComposeLogic->>ComposeLogic: - tool_grammar presence + ComposeLogic->>ComposeLogic: - tool_choice_required + ComposeLogic->>ComposeLogic: - forced_tool_name presence + + ComposeLogic->>ComposeLogic: If multiple grammars: merge_top_level_grammars() + ComposeGrammar->>EarleyCompiler: Compile direct alternation + EarleyCompiler->>LexerBuilder: Build lexer spec + + LexerBuilder->>ComposeLogic: Return single TopLevelGrammar + + Note over ComposeLogic: Generated grammar: + start: TEXT | tool_call + TEXT: /[\x20-\x7E\x80-\uFFFF\n\r\t]/ + tool_call: "<‌tool_call>" ws json_array ws "<‌/tool_call>" + json_array: "[" obj ("," obj)* "]" + obj_search: %json {...} + obj_weather: %json {...} + obj: obj_search | obj_weather + ws: /[ \t\r\n]+/ +``` + + + +### Helper Functions + + + +**Location**: [`src/utils/guidance.rs:161-206`](src/utils/guidance.rs:161-206) + +This function combines multiple `TopLevelGrammar` objects into a single grammar with direct alternation at the start rule. + +```rust +pub fn merge_top_level_grammars( + grammars: Vec, + max_tokens: Option, + start_separator: Option, +) -> TopLevelGrammar { + // Extract all Lark grammar strings + let mut lark_parts = Vec::new(); + + for (_i, g) in grammars.iter().enumerate() { + for gw in &g.grammars { + if let Some(lark) = &gw.lark_grammar { + lark_parts.push(lark.clone()); + } + } + } + + if lark_parts.is_empty() { + let lark_start_exp = format!("start: TEXT\n{}", chat_text_expression()); + let mut tlg = TopLevelGrammar::from_lark(lark_start_exp); + tlg.max_tokens = max_tokens; + return tlg; + } + + // Parse each grammar and extract start RHS + other rules + let mut combined_start_rhs = Vec::new(); + let mut all_other_rules = Vec::new(); + + for lark in lark_parts.iter() { + let (start_rhs, other_rules) = parse_lark_grammar(lark); + combined_start_rhs.push(start_rhs); + all_other_rules.extend(other_rules); + } + + // Combine all other rules, handling duplicates + let combined_rules = combine_rules(all_other_rules); + + // Build new grammar with direct alternation at start + let start_separator = format!(" {} ", &start_separator.unwrap_or_else(|| "|".to_string())); + let start_alternation = combined_start_rhs.join(&start_separator); + let final_grammar = format!("start: {}\n{}", start_alternation, combined_rules); + + let mut top_gram = TopLevelGrammar::from_lark(final_grammar); + top_gram.max_tokens = max_tokens; + top_gram +} +``` + +#### `parse_lark_grammar()` + +**Location**: [`src/utils/guidance.rs:85-114`](src/utils/guidance.rs:85-114) + +Extracts the start rule RHS and other rules from a Lark grammar string. + +```rust +fn parse_lark_grammar(lark: &str) -> (String, Vec) { + let lines: Vec<&str> = lark.lines().collect(); + if lines.is_empty() { + return (String::new(), Vec::new()); + } + + let first_line = lines[0].trim(); + if first_line.starts_with("start:") { + // Extract only the rule names after "start:", not the full rule definition + let rhs_part = first_line.strip_prefix("start:").unwrap_or("").trim(); + + // Parse the RHS to get individual rule names (separated by |) + let rule_names: Vec = rhs_part + .split('|') + .map(|s| s.trim().to_string()) + .collect(); + let start_rhs = rule_names.join(" | "); + + // Return all remaining lines as other rules + let other_rules: Vec = lines[1..].iter().map(|s| s.to_string()).collect(); + (start_rhs, other_rules) + } else { + // No start rule - treat entire grammar as the start rule + (lark.to_string(), Vec::new()) + } +} +``` + +#### `combine_rules()` + +**Location**: [`src/utils/guidance.rs:117-156`](src/utils/guidance.rs:117-156) + +Merges grammar rules, handling duplicate rule names by combining them with alternation. + +```rust +fn combine_rules(rules: Vec) -> String { + if rules.is_empty() { + return String::new(); + } + + use std::collections::HashMap; + let mut rule_groups: HashMap> = HashMap::new(); + + for rule in rules { + let rule = rule.trim(); + if rule.is_empty() { + continue; + } + + // Find the rule name (before the first ":") + if let Some(colon_pos) = rule.find(':') { + let name = rule[..colon_pos].trim().to_string(); + let body = rule[colon_pos + 1..].trim().to_string(); + rule_groups.entry(name).or_default().push(body); + } else { + // Rule without colon - add as-is + rule_groups.entry("anonymous".to_string()).or_default().push(rule.to_string()); + } + } + + // Reconstruct rules, merging duplicates + let mut combined = Vec::new(); + for (name, bodies) in rule_groups { + if bodies.len() == 1 { + combined.push(format!("{}: {}", name, bodies[0])); + } else { + // Multiple definitions for same rule - combine with alternation + combined.push(format!("{}: {}", name, bodies.join(" | "))); + } + } + + combined.join("\n") +} +``` + +### Tool Grammar Construction + +**Location**: [`src/tools/schema.rs:87-145`](src/tools/schema.rs:87-145) + +The `build_json_tool_lark_grammar()` function creates properly ordered tool grammars: + +```rust +pub fn build_json_tool_lark_grammar( + tools: &[Tool], + start: &str, + end: &str, + start_is_special: bool, + end_is_special: bool, +) -> llguidance::api::TopLevelGrammar { + let lark = build_json_tool_lark_string(tools, start, end, start_is_special, end_is_special); + top_level_grammar_from_lark(&lark) +} + +fn build_json_tool_lark_string( + tools: &[Tool], + start: &str, + end: &str, + start_is_special: bool, + end_is_special: bool, +) -> String { + let mut obj_rules = Vec::new(); + for tool in tools { + let tool_name = tool.function.name.replace("-", "_"); + let schema_str = serde_json::to_string(&tool.function.parameters).unwrap_or_default(); + obj_rules.push(format!("obj_{}: %json {}", tool_name, schema_str)); + } + let start_tag = if start.is_empty() { + "<‌tool_call>".to_string() + } else { + lark_literal(start, start_is_special) + }; + let end_tag = if end.is_empty() { + "<‌/tool_call>".to_string() + } else { + lark_literal(end, end_is_special) + }; + let ws = lark_ws_regex(); + + // Build the complete grammar with rules in correct dependency order + let mut all_rules = Vec::new(); + all_rules.push("start: tool_call".to_string()); + all_rules.push(format!("tool_call: {} ws json_array ws {}", start_tag, end_tag)); + all_rules.push("json_array: \"[\" obj (\",\" obj)* \"]\"".to_string()); + + if obj_rules.is_empty() { + // No tools - use a generic object schema + all_rules.push("obj: %json {\"type\": \"object\"}".to_string()); + } else { + // Individual obj_* rules must come BEFORE the obj: rule that references them + all_rules.extend(obj_rules.clone()); + // obj: rule references all obj_* rules via alternation + all_rules.push(format!("obj: {}", obj_rules.iter().map(|r| { + r.trim().split(':').next().unwrap_or("obj").to_string() + }).collect::>().join(" | "))); + } + // ws comes LAST - it's a helper rule for whitespace + all_rules.push(format!("ws: {}", ws)); + + all_rules.join("\n") + "\n" +} +``` + +**Generated grammar order**: +```lark +start: tool_call +tool_call: <‌tool_call> ws json_array ws <‌/tool_call> +json_array: "[" obj ("," obj)* "]" +obj_search: %json {"type":"object","properties":{...}} +obj_weather: %json {"type":"object","properties":{...}} +obj: obj_search | obj_weather +ws: /[ \t\r\n]+/ +``` + +**Note**: The tool call tags (`<‌tool_call>`, `<‌/tool_call>`) are currently specified as **string literals** in the Lark grammar, not as token IDs. This requires the tokenizer to recognize these exact byte sequences as single tokens. + +For token ID support, the `lark_special_token()` function exists ([`src/tools/schema.rs:60-67`](src/tools/schema.rs:60-67)) but is not currently used in tool grammar construction. To use token IDs instead of strings, the grammar would need to be generated with `<[151657]>` syntax where 151657 is the token ID for `<‌tool_call>`. + +**Token ID-based alternative** (not currently implemented): +```lark +start: tool_call +tool_call: <[151657]> ws json_array ws <[151658]> +json_array: "[" obj ("," obj)* "]" +... +``` + +--- + +## 4. API REFERENCE + +### Grammar Fragment Building Functions + +These functions convert client request fields to `TopLevelGrammar` objects. + +#### `grammar_fragment_from_structured_outputs()` +**Location**: [`src/server/mod.rs:167`](src/server/mod.rs:167) + +Converts `StructuredOutputs` to `TopLevelGrammar`: + +```rust +pub fn grammar_fragment_from_structured_outputs( + structured: &StructuredOutputs +) -> Result> +``` + +**Parameters**: +- `structured.choice`: Vec → Lark grammar for enum +- `structured.regex`: String → Regex constraint +- `structured.json`: Value → JSON Schema constraint +- `structured.grammar`: String → Lark grammar +- `structured.structural_tag`: Value → QwenCoder-style XML envelope + +**Returns**: `Some(TopLevelGrammar)` or `None` if invalid + +**Logging**: +- `DEBUG`: Building constraint grammar +- `INFO`: Completed with grammar type + +#### `grammar_fragment_from_response_format()` +**Location**: [`src/server/mod.rs:248`](src/server/mod.rs:248) + +Converts `ResponseFormat` to `TopLevelGrammar`: + +```rust +pub fn grammar_fragment_from_response_format( + response_format: &ResponseFormat +) -> Result> +``` + +**Parameters**: +- `response_format.format_type`: Must be "json_schema" +- `response_format.json_schema.schema`: JSON Schema value + +**Returns**: `Some(TopLevelGrammar)` or error + +**Logging**: +- `DEBUG`: Building JSON schema grammar +- `INFO`: Completed with grammar + +### GuidanceState Methods + +#### `new_from_grammar()` +**Location**: [`src/utils/guidance.rs:526`](src/utils/guidance.rs:526) + +Creates a new GuidanceState from a TopLevelGrammar: + +```rust +pub fn new_from_grammar(factory: Arc, grammar: &TopLevelGrammar) -> Result +``` + +**Parameters**: +- `factory`: Arc - llguidance parser factory +- `grammar`: &TopLevelGrammar - The grammar to parse + +**Returns**: `Result` + +**Flow**: +1. `factory.create_parser(grammar)?` → Parser +2. `Matcher::new(Ok(parser))` → Matcher +3. Initialize `llm_tokens`, `llm_bytes`, `slicer_cache` + +**Logging**: +- `DEBUG`: Constraint type +- `DEBUG`: Grammar converted +- `INFO`: GuidanceState created successfully + +#### `compute_mask()` +**Location**: [`src/utils/guidance.rs:543`](src/utils/guidance.rs:543) + +Computes valid token mask: + +```rust +pub fn compute_mask(&mut self) -> Result> +``` + +**Returns**: `Option` with valid token indices, or None if matcher stopped + +**Logging**: +- `TRACE`: Mask computed with N valid tokens + +#### `validate_tokens()` +**Location**: [`src/utils/guidance.rs:672`](src/utils/guidance.rs:672) + +Validates a sequence of tokens: + +```rust +pub fn validate_tokens(&mut self, tokens: &[u32]) -> Option +``` + +**Parameters**: +- `tokens`: &[u32] - Tokens to validate + +**Returns**: `Some(valid_token_count)` or `None` if validation failed + +**Logging**: +- `DEBUG`: Token X rejected by grammar (if invalid) + +#### `commit_token()` +**Location**: [`src/utils/guidance.rs:556`](src/utils/guidance.rs:556) + +Commits a token to the grammar state: + +```rust +pub fn commit_token(&mut self, token: u32) -> Result<()> +``` + +**Parameters**: +- `token`: u32 - Token ID to commit + +**Flow**: +1. `matcher.consume_token(token)?` +2. `llm_tokens.push(token)` +3. `llm_bytes += 4` (approximate bytes per token) + +**Logging**: +- `TRACE`: Token consumed successfully + +#### `consume_ff_tokens()` +**Location**: [`src/utils/guidance.rs:617`](src/utils/guidance.rs:617) + +Consumes fast-forward tokens guaranteed by grammar: + +```rust +pub fn consume_ff_tokens(&mut self) -> Result, anyhow::Error> +``` + +**Returns**: Vec of consumed FF tokens + +**Flow**: +1. `matcher.compute_ff_tokens()` → Vec +2. For each token: `consume_token()` + `llm_tokens.push()` + `llm_bytes += 4` + +**Logging**: +- `DEBUG`: consume_ff_tokens() called +- `DEBUG`: compute_ff_tokens() returned N tokens +- `DEBUG`: Successfully consumed N tokens + +#### `rollback_to()` +**Location**: [`src/utils/guidance.rs:645`](src/utils/guidance.rs:645) + +Rolls back to a previous state: + +```rust +pub fn rollback_to(&mut self, token_pos: usize, byte_pos: usize) -> Result<()> +``` + +**Parameters**: +- `token_pos`: usize - Target token position +- `byte_pos`: usize - Target byte position + +**Flow**: +1. Calculate `tokens_to_rollback = llm_tokens.len() - token_pos` +2. `matcher.rollback(tokens_to_rollback)?` +3. `llm_tokens.truncate(token_pos)` +4. `llm_bytes = byte_pos` + +**Logging**: +- `DEBUG`: Rollback N tokens successful + +#### `num_tokens()` +**Location**: [`src/utils/guidance.rs:571`](src/utils/guidance.rs:571) + +Returns the number of committed tokens: + +```rust +pub fn num_tokens(&self) -> usize +``` + +#### `num_bytes()` +**Location**: [`src/utils/guidance.rs:576`](src/utils/guidance.rs:576) + +Returns the number of committed bytes: + +```rust +pub fn num_bytes(&self) -> usize +``` + +#### `is_finished()` +**Location**: [`src/utils/guidance.rs:581`](src/utils/guidance.rs:581) + +Checks if guidance is finished: + +```rust +pub fn is_finished(&self) -> bool +``` + +#### `last_token()` +**Location**: [`src/utils/guidance.rs:586`](src/utils/guidance.rs:586) + +Gets the last committed token: + +```rust +pub fn last_token(&self) -> Option +``` + +#### `validate_token()` +**Location**: [`src/utils/guidance.rs:591`](src/utils/guidance.rs:591) + +Validates a single token without consuming: + +```rust +pub fn validate_token(&mut self, token: u32) -> bool +``` + +**Returns**: `true` if valid, `false` if rejected + +**Logging**: +- `DEBUG`: Token rejected by grammar (if invalid) + +#### `compute_mask_or_eos()` +**Location**: [`src/utils/guidance.rs:604`](src/utils/guidance.rs:604) + +Computes valid token mask or EOS set: + +```rust +pub fn compute_mask_or_eos(&mut self) -> Result +``` + +**Returns**: `SimpleVob` with valid token indices + +#### `compute_ff_tokens()` +**Location**: [`src/utils/guidance.rs:609`](src/utils/guidance.rs:609) + +Computes fast-forward tokens without consuming: + +```rust +pub fn compute_ff_tokens(&mut self) -> Vec +``` + +**Returns**: Vec of FF tokens + +#### `has_pending_lexeme_bytes()` +**Location**: [`src/utils/guidance.rs:640`](src/utils/guidance.rs:640) + +Checks if there are pending lexeme bytes: + +```rust +pub fn has_pending_lexeme_bytes(&self) -> bool +``` + +#### `capture_snapshot()` +**Location**: [`src/utils/guidance.rs:656`](src/utils/guidance.rs:656) + +Captures current state as rollback snapshot (no-op in current implementation): + +```rust +pub fn capture_snapshot(&mut self) +``` + +#### `clear()` +**Location**: [`src/utils/guidance.rs:660`](src/utils/guidance.rs:660) + +Clears all state: + +```rust +pub fn clear(&mut self) +``` + +### Helper Functions + +#### `compose_grammars()` +**Location**: [`src/utils/guidance.rs:363`](src/utils/guidance.rs:363) + +Composes multiple grammars into a single TopLevelGrammar: + +```rust +pub fn compose_grammars( + constraint_grammars: Vec, + tool_grammar: Option, + has_tools: bool, + tool_choice_required: bool, + forced_tool_name: Option, + max_tokens: Option, +) -> TopLevelGrammar +``` + +See Section 3 for full documentation. + +#### `chat_text_expression_with_eos()` + +**Location**: [`src/utils/guidance.rs:485-514`](src/utils/guidance.rs:485-514) + +Returns the TEXT pattern with explicit EOS token IDs for free-form text matching with proper termination: + +```rust +pub fn chat_text_expression_with_eos(special_tokens: &SpecialTokens) -> String { + let eos_token_ids = special_tokens.eos_ids(); + + // First check environment variable override + if let Ok(val) = std::env::var("VLLM_LLG_DEFAULT_TEXT") { + return format!("{}", val); + } + + // Build EOS alternation pattern using <[id]> syntax for token IDs + if eos_token_ids.is_empty() { + // Fallback to stop="" when no EOS tokens available + r#"start: text +text[stop=""]: /((?s).*?)/"#.to_string() + } else if eos_token_ids.len() == 1 { + format!(r#"start: text_with_eos +text_with_eos: TEXT eos? +TEXT: /(?s:.*)/ +eos: <[{}]>"#, eos_token_ids[0]) + } else { + let ids: Vec = eos_token_ids.iter().map(|id| format!("<[{}]>", id)).collect(); + let eos_alternation = ids.join(" | "); + format!(r#"start: text_with_eos +text_with_eos: TEXT eos? +TEXT: /(?s:.*)/ +eos: {}"#, eos_alternation) + } +} +``` + +This function: +1. Extracts EOS token IDs from `SpecialTokens` +2. Builds a TEXT pattern with optional EOS termination (`eos?`) +3. Uses `<[token_id]>` syntax for token ID references in the Lark grammar +4. Falls back to `stop=""` pattern when no EOS tokens are available + +#### `merge_top_level_grammars()` +**Location**: [`src/utils/guidance.rs:161`](src/utils/guidance.rs:161) + +Merges multiple TopLevelGrammar objects with direct alternation: + +```rust +pub fn merge_top_level_grammars( + grammars: Vec, + max_tokens: Option, + start_separator: Option, +) -> TopLevelGrammar +``` + +#### `build_tool_call_lark()` +**Location**: [`src/utils/guidance.rs:483`](src/utils/guidance.rs:483) + +Builds Lark grammar string for tool calls: + +```rust +pub fn build_tool_call_lark( + tools: &[Tool], + schema_map: &Arc>, + start: &str, + end: &str, +) -> String +``` + +#### `lark_ws_regex()` +**Location**: [`src/utils/guidance.rs:478`](src/utils/guidance.rs:478) + +Returns the whitespace regex pattern for Lark grammars: + +```rust +pub fn lark_ws_regex() -> &'static str +``` + +#### `chat_text_expression()` +**Location**: [`src/utils/guidance.rs:306`](src/utils/guidance.rs:306) + +Returns the TEXT pattern for free-form text matching: + +```rust +pub fn chat_text_expression() -> String +``` + +#### `sanitize_to_ascii()` +**Location**: [`src/utils/guidance.rs:16`](src/utils/guidance.rs:16) + +Sanitizes a string by removing non-ASCII bytes: + +```rust +pub fn sanitize_to_ascii(s: &str) -> String +``` + +#### `sanitize_utf8_valid()` +**Location**: [`src/utils/guidance.rs:24`](src/utils/guidance.rs:24) + +Sanitizes a string by removing invalid UTF-8 sequences: + +```rust +pub fn sanitize_utf8_valid(s: &str) -> String +``` + +#### `top_level_grammar_from_regex()` +**Location**: [`src/utils/guidance.rs:36`](src/utils/guidance.rs:36) + +Creates TopLevelGrammar from regex: + +```rust +pub fn top_level_grammar_from_regex(regex: &str) -> TopLevelGrammar +``` + +#### `top_level_grammar_from_lark()` +**Location**: [`src/utils/guidance.rs:42`](src/utils/guidance.rs:42) + +Creates TopLevelGrammar from Lark string: + +```rust +pub fn top_level_grammar_from_lark(lark: &str) -> TopLevelGrammar +``` + +#### `top_level_grammar_from_json_schema()` +**Location**: [`src/utils/guidance.rs:48`](src/utils/guidance.rs:48) + +Creates TopLevelGrammar from JSON schema: + +```rust +pub fn top_level_grammar_from_json_schema(schema: serde_json::Value) -> Result +``` + +#### `get_lark_from_top_level_grammar()` +**Location**: [`src/utils/guidance.rs:209`](src/utils/guidance.rs:209) + +Extracts the Lark grammar string from TopLevelGrammar: + +```rust +pub fn get_lark_from_top_level_grammar(gram: &TopLevelGrammar) -> String +``` + +#### `build_grammar_vec()` +**Location**: [`src/utils/guidance.rs:316`](src/utils/guidance.rs:316) + +Builds grammar vec based on constraint and tool presence: + +```rust +pub fn build_grammar_vec( + constraint_grammars: Vec, + tool_grammar: Option, + tool_choice_required: bool, +) -> Vec +``` + +### BuildLLG Factory Functions + +#### `build_llg_factory()` +**Location**: [`src/utils/guidance.rs:449`](src/utils/guidance.rs:449) + +Builds a ParserFactory for llguidance: + +```rust +pub fn build_llg_factory( + tokenizer: Tokenizer, + vocab_size: Option, +) -> Result> +``` + +#### `load_toktrie_from_path()` +**Location**: [`src/utils/guidance.rs:471`](src/utils/guidance.rs:471) + +Loads a TokTrie from a file path: + +```rust +pub fn load_toktrie_from_path(path: impl AsRef) -> Result +``` + +### GuidanceState Methods + +#### `new_from_grammar()` +**Location**: [`src/utils/guidance.rs:425-439`](src/utils/guidance.rs:425-439) + +Creates a new GuidanceState from a constraint: + +```rust +pub fn new_from_grammar(factory: Arc, grammar: &TopLevelGrammar) -> Result +``` + +**Flow**: +1. `factory.create_parser(grammar)?` → Parser +2. `Matcher::new(Ok(parser))` → Matcher +3. Initialize `llm_tokens`, `llm_bytes`, `slicer_cache` + +**Logging**: +- `DEBUG`: Constraint type +- `DEBUG`: Grammar converted +- `INFO`: GuidanceState created successfully + +#### `validate_token()` +**Location**: [`src/utils/guidance.rs:490-500`](src/utils/guidance.rs:490-500) + +Validates a single token without consuming: + +```rust +pub fn validate_token(&mut self, token: u32) -> bool +``` + +**Returns**: `true` if valid, `false` if rejected + +**Logging**: +- `DEBUG`: Token rejected by grammar (if invalid) + +#### `commit_token()` +**Location**: [`src/utils/guidance.rs:455-467`](src/utils/guidance.rs:455-467) + +Commits a token to the grammar state: + +```rust +pub fn commit_token(&mut self, token: u32) -> Result<()> +``` + +**Flow**: +1. `matcher.consume_token(token)?` +2. `llm_tokens.push(token)` +3. `llm_bytes += 4` (approximate bytes per token) + +**Logging**: +- `TRACE`: Token consumed successfully + +#### `compute_mask_or_eos()` +**Location**: [`src/utils/guidance.rs:503-505`](src/utils/guidance.rs:503-505) + +Computes valid token mask or EOS set: + +```rust +pub fn compute_mask_or_eos(&mut self) -> Result +``` + +**Returns**: `SimpleVob` with valid token indices + +**Logging**: +- `TRACE`: Mask computed with N valid tokens + +#### `consume_ff_tokens()` +**Location**: [`src/utils/guidance.rs:516-536`](src/utils/guidance.rs:516-536) + +Consumes fast-forward tokens guaranteed by grammar: + +```rust +pub fn consume_ff_tokens(&mut self) -> Result, anyhow::Error> +``` + +**Flow**: +1. `matcher.compute_ff_tokens()` → Vec +2. For each token: `consume_token()` + `llm_tokens.push()` + `llm_bytes += 4` + +**Returns**: Vec of consumed FF tokens + +**Logging**: +- `DEBUG`: consume_ff_tokens() called +- `DEBUG`: compute_ff_tokens() returned N tokens +- `DEBUG`: Successfully consumed N tokens + +#### `rollback_to()` +**Location**: [`src/utils/guidance.rs:544-552`](src/utils/guidance.rs:544-552) + +Rolls back to a previous state: + +```rust +pub fn rollback_to(&mut self, token_pos: usize, byte_pos: usize) -> Result<()> +``` + +**Flow**: +1. Calculate `tokens_to_rollback = llm_tokens.len() - token_pos` +2. `matcher.rollback(tokens_to_rollback)?` +3. `llm_tokens.truncate(token_pos)` +4. `llm_bytes = byte_pos` + +**Logging**: +- `DEBUG`: Rollback N tokens successful + +### ModelRunner Methods + +#### `validate_sequence_for_grammar()` +**Location**: [`src/core/runner.rs:1597-1604`](src/core/runner.rs:1597-1604) + +Validates entire sequence against grammar: + +```rust +pub fn validate_sequence_for_grammar( + &self, + seq_id: usize, + output_ids: &[u32] +) -> Option +``` + +**Returns**: `Some(valid_token_count)` or `None` if no constraint + +**Flow**: +1. Get GuidanceState for seq_id +2. Call `state.validate_tokens(output_ids)` +3. Map Result → Option + +**Logging**: +- None (internal operation) + +#### `rollback_sequence_for_guidance()` +**Location**: [`src/core/runner.rs:1607-1614`](src/core/runner.rs:1607-1614) + +Rolls back guidance state for a sequence: + +```rust +pub fn rollback_sequence_for_guidance( + &self, + seq_id: usize, + target_tokens: usize +) -> Result<()> +``` + +**Flow**: +1. Get GuidanceState for seq_id +2. Calculate `target_bytes = target_tokens * 4` +3. Call `state.rollback_to(target_tokens, target_bytes)` + +**Logging**: +- None (internal operation) + +#### `consume_ff_tokens()` +**Location**: [`src/core/runner.rs:1618-1628`](src/core/runner.rs:1618-1628) + +Consumes FF tokens for a sequence: + +```rust +pub fn consume_ff_tokens(&self, seq_id: usize) -> Result> +``` + +**Returns**: FF tokens consumed + +**Flow**: +1. Get GuidanceState for seq_id +2. Call `state.consume_ff_tokens()` +3. Map errors to candle_core::Error + +**Logging**: +- None (internal operation) + +### BlockManager Methods + +#### `rollback_to_seq_tokens()` +**Location**: [`src/core/block_manager.rs:946-1005`](src/core/block_manager.rs:946-1005) + +Rolls back sequence to token position: + +```rust +pub fn rollback_to_seq_tokens( + &mut self, + seq: &mut Sequence, + target_tokens: usize +) -> Result<()> +``` + +**Flow**: +1. Calculate `target_blocks = target_tokens.div_ceil(self.block_size)` +2. Calculate `blocks_to_release = current_blocks - target_blocks` +3. Release blocks from end +4. Update `seq.num_cached_tokens` +5. Clean up prefix cache entries +6. Invalidate Mamba prefix hashes + +**Logging**: +- None (internal operation) + +--- + +## 5. USAGE EXAMPLES + +### Example 1: Enable Tool Grammar Generation + +**CLI**: +```bash +./vllm-rs --enable-tool-grammar --allow-constraint-api +``` + +**In code**: +```rust +let econfig = EngineConfig::new( + // ... other params ... + allow_constraint_api: false, + enable_tool_grammar: true, // Auto-generate tool grammar +); +``` + +When enabled, the system will: +1. Build Lark grammar from `resolved_tools` via [`build_json_tool_lark_grammar()`](src/tools/schema.rs:87) +2. Embed all tool schemas as `%json` directives +3. Make tool calls optional via `start: (TEXT | tool_call)+` (allows mid-conversation tool calls) + +### Example 2: Structured Outputs (OpenAI-style) + +There are two equivalent ways to specify structured outputs: + +**Top-level format** (recommended for convenience): +```json +{ + "messages": [{"role": "user", "content": "Generate a user profile"}], + "structured_outputs": { + "json": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + } +} +``` + +**OpenAI-compatible format** (via `extra_body`): +```json +{ + "messages": [{"role": "user", "content": "Generate a user profile"}], + "extra_body": { + "structured_outputs": { + "json": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + } + } +} +``` + +Both formats produce identical results. The top-level format is more convenient for direct API calls, while `extra_body` maintains OpenAI compatibility. + +### Example 3: Response Format (OpenAI-compatible) + +```json +{ + "messages": [{"role": "user", "content": "Provide a mathematical reasoning"}], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "math_reasoning", + "schema": { + "type": "object", + "properties": { + "steps": {"type": "array", "items": {"type": "string"}}, + "final_answer": {"type": "string"} + }, + "required": ["steps", "final_answer"] + } + } + } +} +``` + +### Example 4: Custom Lark Grammar + +Using the legacy `constraint` field with `constraint_type`: + +```json +{ + "messages": [{"role": "user", "content": "Generate a phone number"}], + "constraint": "start: 'Hello' _WS? 'World' _WS? '!'", + "constraint_type": "lark" +} +``` + +Or via structured_outputs: + +```json +{ + "messages": [{"role": "user", "content": "Generate a date"}], + "structured_outputs": { + "grammar": "start: date\\n date: year \"-\" month \"-\" day\\n year: /[0-9]{4}/\\n month: /[0-9]{2}/\\n day: /[0-9]{2}/" + } +} +``` + +### Example 5: Regular Expression Constraint + +Using the legacy `constraint` field: + +```json +{ + "messages": [{"role": "user", "content": "Generate a number"}], + "constraint": "^number\\s\\d{3}-\\d{3}-\\d{4}$", + "constraint_type": "regex" +} +``` + +Or via structured_outputs: + +```json +{ + "messages": [{"role": "user", "content": "Generate a number"}], + "structured_outputs": { + "regex": "^number\\s\\d{3}-\\d{3}-\\d{4}$" + } +} +``` + +### Example 6: Choice/Enum Constraint + +```json +{ + "messages": [{"role": "user", "content": "Classify this sentiment"}], + "structured_outputs": { + "choice": ["positive", "negative", "neutral"] + } +} +``` + +--- + +## 6. MATHEMATICAL FOUNDATIONS + +### Token Validation Probability + +The llguidance matcher computes the probability of each token being valid given the current grammar state: + +``` +P(token | grammar_state) = + 1.0 if token ∈ valid_tokens(grammar_state) + 0.0 otherwise +``` + +### FF Token Computation + +Fast-forward tokens are computed by exploring the grammar automaton: + +``` +FF_tokens = longest_prefix(w) where: + w ∈ Σ* (input alphabet) + ∧ δ(q0, w) ∈ F (final states) + ∧ ∀prefix p of w: δ(q0, p) defined +``` + +### Rollback Cost + +The rollback operation has O(n) complexity where n = tokens_to_rollback: + +``` +rollback_cost = O(n) + O(k) + where n = tokens to rollback + where k = bytes to adjust +``` + +### Mask Computation Complexity + +``` +mask_computation = O(|V| * |grammar_rules|) + where |V| = vocabulary size + where |grammar_rules| = number of grammar rules +``` + +With caching (SlicerCache), repeated queries at the same position are O(1). + +--- + +## 7. ROLLBACK MECHANICS + +### State Consistency Guarantees + +Before rollback: +- `Sequence.token_ids`: All tokens including invalid ones +- `Sequence.block_table`: All allocated blocks +- `Sequence.num_cached_tokens`: Full cached count +- `GuidanceState.llm_tokens`: All committed tokens +- `GuidanceState.matcher`: Parser state at invalid position +- `BlockManager.prefix_cache`: All cached entries + +After rollback: +- `Sequence.token_ids`: Truncated to valid position +- `Sequence.block_table`: Truncated to valid blocks +- `Sequence.num_cached_tokens`: Block-aligned value +- `GuidanceState.llm_tokens`: Truncated to valid position +- `GuidanceState.matcher`: Parser state at valid position +- `BlockManager.prefix_cache`: Cleaned for evicted blocks + +### Rollback Steps + +1. **Save Snapshot**: Store current state for potential recovery +2. **Truncate Sequence**: Remove invalid tokens from token_ids +3. **Truncate Blocks**: Remove blocks beyond target position +4. **Release KV Cache**: Decrement block reference counts +5. **Clean Prefix Cache**: Remove entries for released blocks +6. **Invalidate Mamba**: Remove Mamba prefix mappings +7. **Rollback Matcher**: Reset grammar state to valid position +8. **Reset Status**: Mark sequence as Running for reprocessing + +### Error Handling + +If rollback fails: +- Log error with full state dump +- Mark sequence as Finished to release resources +- Do NOT attempt partial rollback + +--- + +## 8. PERFORMANCE CONSIDERATIONS + +### Positive Impacts + +1. **Reduced re-sampling**: FF tokens skip ahead to valid continuations +2. **Smaller logit space**: Mask reduces candidates from vocab_size to valid set +3. **Early rejection**: Validation catches failures before streaming + +### Tradeoffs + +1. **Memory overhead**: GuidanceState stored per-sequence (~100KB) +2. **Parsing overhead**: StreamToolParser tracks incremental state +3. **Rollback cost**: O(n) where n = tokens to rollback + +### Recommendations + +- Use `--enable-tool-grammar` for tool-heavy workloads +- Use structured_outputs for complex JSON schemas +- Monitor `guidance_failed` counter for constraint issues + +--- + +## 9. LOGGING LEVELS + +| Level | Use Case | Example | +|-------|----------|---------| +| `TRACE` | Token-level operations | "Token 123 consumed successfully" | +| `DEBUG` | Constraint processing | "Building Lark grammar from choice options" | +| `INFO` | State changes | "GuidanceState created successfully" | +| `WARN` | Validation failures | "Token 456 rejected by grammar" | +| `ERROR` | Rollback failures | "Guidance rollback failed: ..." | + +--- + +## 10. CLI FLAGS REFERENCE + +| Flag | Default | Description | +|------|---------|-------------| +| `--allow-constraint-api` | `false` | Allow client to submit structured_outputs/response_format | +| `--enable-tool-grammar` | `false` | Automatically build LLG grammar from tools | +| `--prefix-cache` | `false` | Enable prefix caching | +| `--fp8-kvcache` | `false` | Use FP8 quantization for KV cache | + +--- + +## 11. TROUBLESHOOTING + +### Issue: "Guidance mask length is 0" + +**Cause**: Constraint is too restrictive, no tokens valid + +**Solution**: +- Check constraint grammar/schema +- Enable `allow_constraint_api` for debugging +- Remove or set `grammar: null` for non-constrained generation + +### Issue: "structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag" + +**Cause**: Multiple constraint fields specified in structured_outputs + +**Solution**: Only specify one constraint type in request + +### Issue: "Unsupported response_format type" + +**Cause**: response_format.type is not "json_schema" + +**Solution**: Use only supported types or use structured_outputs instead + +### Issue: "Tool buffering exceeded timeout" + +**Cause**: Streaming tool call taking too long to complete + +**Solution**: +- Increase `VLLM_RS_TOOL_BUFFER_TIMEOUT_SECS` +- Check for malformed tool call JSON +- Verify tool parser configuration + +--- + +## 12. TESTING & VALIDATION + +### Testing Grammar-Driven Guidance via curl + +#### Example 1: Phone Number Format (Regex Constraint) + +**Enable client constraints**: +```bash +vllm-rs --m unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF --f Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf \ + --ui-server --allow-constraint-api +``` + +**Test request** (top-level structured_outputs): +```bash +curl -sXPOST localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role":"user","content":"Generate a phone number"}], + "constraint": "^number:\\s\\s\\d{3}-\\d{3}-\\d{4}\\ndo you want a sandwitch with that\\s\\S{6}", + "constraint_type": "regex" + }' | jq -r '.choices[0].message.content' +``` + +**Expected output**: +``` +number: 123-456-7890 +do you want a sandwitch with that number? +``` + +--- + +#### Example 2: JSON Schema Constraint (Structured Outputs) + +**Test request**: +```bash +curl -sXPOST localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role":"user","content":"Generate a user profile"}], + "structured_outputs": { + "json": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer", "minimum": 0, "maximum": 150}, + "email": {"type": "string", "pattern": "^[a-z]+@[a-z]+\\.[a-z]+$"} + }, + "required": ["name", "age", "email"], + "additionalProperties": false + } + }, + "max_tokens": 500 + }' | jq -r '.choices[0].message.content' +``` + +**Expected output**: JSON with `name` (string), `age` (integer), `email` (string matching pattern) + +--- + +#### Example 3: Tool Grammar Generation (Auto-LLG) + +**Enable tool grammar**: +```bash +vllm-rs --m unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF --f Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf \ + --ui-server --enable-tool-grammar --mcp-config ./mcp.json +``` + +**Test request with tools**: +```bash +curl -sXPOST localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role":"user","content":"What is the weather in London?"}], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"} + }, + "required": ["location"] + } + } + }], + "tool_choice": "auto", + "max_tokens": 500 + }' | jq -r '.choices[0].message.content' +``` + +**Expected output**: Tool call in proper format with `name` and `arguments` + +--- + +#### Example 4: Choice/Enum Constraint (Lark Grammar) + +**Test request**: +```bash +curl -sXPOST localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role":"user","content":"Classify this sentiment"}], + "structured_outputs": { + "choice": ["positive", "negative", "neutral"] + }, + "max_tokens": 50 + }' | jq -r '.choices[0].message.content' +``` + +**Expected output**: One of `positive`, `negative`, or `neutral` (quoted string) + +--- + +#### Example 5: Custom Lark Grammar + +**Test request**: +```bash +curl -sXPOST localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role":"user","content":"Generate a date"}], + "structured_outputs": { + "grammar": "start: date\\n date: year \"-\" month \"-\" day\\n year: /[0-9]{4}/\\n month: /[0-9]{2}/\\n day: /[0-9]{2}/" + }, + "max_tokens": 50 + }' | jq -r '.choices[0].message.content' +``` + +**Expected output**: Date in `YYYY-MM-DD` format + +--- + +### Verification Checklist + +For each test, verify: +1. [ ] Response contains only tokens valid per the grammar/constraint +2. [ ] No invalid JSON structure produced +3. [ ] Tool calls follow proper `name`/`arguments` format +4. [ ] Regex patterns matched exactly +5. [ ] Enum choices limited to specified options + +--- + +### Log Messages to Watch For + +| Message | Meaning | +|---------|---------| +| `[llg] Applied constraint to params` | Constraint successfully set from tools | +| `[llg] GuidanceState created successfully` | Grammar parser initialized | +| `[llg] Token X rejected by grammar` | Token validation failed | +| `[llg] Resampled token X consumed by matcher` | Re-sampling worked correctly | +| `[Seq X] Exceeded 3 rollback attempts` | Rolling back too often - check constraint | + +--- + +## 14. TOKEN ID BASED LARK GRAMMAR CALL GRAPH + +### Overview + +When `start_token_ids` and `end_token_ids` are provided to `build_json_tool_lark_grammar()` or `build_xml_tool_lark_grammar()`, the system uses token ID syntax (`<[token_id]>`) instead of string literals in the Lark grammar. + +### Call Graph + +``` +Server Request + │ + ▼ +[server/server.rs:458-466] build_json_tool_lark_grammar() + │ + ├─ tool_config.start_token_ids (e.g., {151657}) + ├─ tool_config.end_token_ids (e.g., {151658}) + │ + ▼ +[tools/schema.rs:87-100] build_json_tool_lark_grammar() + │ + ├─ Accepts start_token_ids: Option<&HashSet> + ├─ Accepts end_token_ids: Option<&HashSet> + │ + ▼ +[tools/schema.rs:118-143] build_json_tool_lark_string() + │ + ├─ if start_token_ids.is_some_and(|ids| !ids.is_empty()): + │ └─ [tools/schema.rs:60-67] lark_special_token(ids) + │ └─ Returns: "<[151657]>" (token ID syntax) + │ + ├─ else: + │ └─ Uses lark_literal(start, start_is_special) + │ └─ Returns: "\"\"" (string literal syntax) + │ + ▼ +[Lark Grammar String] + │ + ├─ Token ID mode: "tool_call: <[151657]> ws json_array ws <[151658]>" + └─ String mode: "tool_call: \"\" ws json_array ws \"\"" +``` + +### lark_special_token() Function + +**Location**: [`src/tools/schema.rs:60-67`](src/tools/schema.rs:60-67) + +```rust +fn lark_special_token(token_ids: &HashSet) -> String { + if token_ids.is_empty() { + return String::new(); + } + // Join multiple token IDs with | + let ids: Vec = token_ids.iter().map(|id| format!("[{}]", id)).collect(); + format!("<{}>", ids.join(",")) +} +``` + +### Example Output + +With token IDs `{151657, 151658}`: +``` +tool_call: <[151657]> ws json_array ws <[151658]> +``` + +Without token IDs (fallback to strings): +``` +tool_call: "" ws json_array ws "" +``` + +### Tests Verifying Token ID Support + +1. **`test_build_json_tool_lark_grammar_qwen3_with_token_ids`** (lines 764-783) + - Verifies that token IDs are converted to `<[token_id]>` syntax + - Checks that the generated grammar contains the correct token IDs + +2. **`test_lark_special_token_single_id`** (lines 785-791) + - Tests single token ID conversion: `<[151657]>` + +3. **`test_lark_special_token_multiple_ids`** (lines 793-800) + - Tests multiple token IDs: `<[151657],[151658]>` + +4. **`test_lark_special_token_empty`** (lines 802-807) + - Tests empty token ID set returns empty string + +--- + +## 13. GRAMMAR CONSTRUCTION DETAILS + +### The `rule_N` Indirection Problem + +**Old behavior** (incorrect): +```lark +start: rule_0 | rule_1 +rule_0: TEXT +TEXT: /(.|[\\n\\r])*/ +rule_1: +tool_call: ... +``` + +This creates an unnecessary level of indirection where: +1. `rule_0` references `TEXT` (which is actually a terminal) +2. `rule_1` is empty and just wraps `tool_call` +3. The `start` rule alternates between these wrappers + +**New behavior** (correct): +```lark +start: TEXT | tool_call +TEXT: /((?s).)*/ # (?s) enables dotall mode +tool_call: ... +``` + +This produces a flat grammar where: +1. `start` directly alternates between `TEXT` and `tool_call` +2. No intermediate `rule_N` wrappers +3. Cleaner, more efficient grammar + +### Implementation + +The fix is implemented in two helper functions: + +1. **`parse_lark_grammar()`**: Extracts the start rule's RHS and remaining rules +2. **`combine_rules()`**: Merges rules while handling duplicates + +### Performance Impact + +- **Smaller grammar size**: No intermediate rule wrappers +- **Faster parsing**: Fewer Earley items to track +- **Lower memory usage**: Simpler grammar structure +- **Better error messages**: Direct alternation is easier to understand + +--- + +--- + +## 15. EOS TOKEN MANDATE FOR FREEFORM GENERATION + +### Why EOS Tokens Are Required + +For freeform TEXT generation (non-constrained), the grammar MUST include an explicit EOS token boundary. Without it: + +1. **Mask Preemption**: The `compute_mask()` function returns token IDs before generation, but the TEXT pattern `/((?s).)*/` allows any character including EOS +2. **No Finite Boundary**: Without an explicit EOS in the grammar, the lexer has no way to know when to stop accepting TEXT tokens +3. **Run-on Generation**: The model continues generating indefinitely until max_tokens is reached + +### Correct TEXT Pattern with EOS + +```lark +start: text_with_eos +text_with_eos: TEXT eos? +TEXT: /(?s:.*)/ +eos: <[248044]> | <[248046]> | <[248048]> | <[248052]> | <[248054]> | <[248050]> +``` + +### Incorrect TEXT Pattern (causes run-on generation) + +```lark +start: text +text: TEXT +TEXT: /(?s:.*)/ +``` + +### Implementation in chat_text_expression_with_eos() + +The function [`chat_text_expression_with_eos()`](src/utils/guidance.rs:485) in guidance.rs properly handles this: + +```rust +pub fn chat_text_expression_with_eos(special_tokens: &SpecialTokens) -> String { + let eos_token_ids = special_tokens.eos_ids(); + + let eos_pattern = if eos_token_ids.is_empty() { + // Fallback to stop="" when no EOS tokens available + r#"start: text +text[stop=""]: /((?s).*?)/"#.to_string() + } else if eos_token_ids.len() == 1 { + format!(r#"start: text_with_eos +text_with_eos: TEXT eos? +TEXT: /(?s:.*)/ +eos: <[{}]>"#, eos_token_ids[0]) + } else { + let ids: Vec = eos_token_ids.iter().map(|id| format!("<[{}]>", id)).collect(); + let eos_alternation = ids.join(" | "); + format!(r#"start: text_with_eos +text_with_eos: TEXT eos? +TEXT: /(?s:.*)/ +eos: {}"#, eos_alternation) + }; + + eos_pattern +} +``` + +### Key Points + +1. **Use `chat_text_expression_with_eos()`** instead of `chat_text_expression()` when freeform TEXT is needed +2. **Always include EOS tokens** in the grammar for unconstrained generation +3. **Avoid `stop=""` patterns** - they don't work reliably with llguidance's lexer +4. **Use `eos?` syntax** to make EOS optional at the end of text + +--- + +## 16. QWEN CODER TOOL PARSING ISSUES + +### Problem: XML Nested Tags in Parameter Values + +Qwen Coder models output tool parameters with XML-style nested tags like: + +```xml +<‌tool_call> +<‌function=edit_file> +<‌parameter=file_path>/tmp/a.rs +<‌parameter=new_string> +fn a() { let x = vec![1,2,3]; } +<‌/function> +<‌/tool_call> +``` + +### The Grammar Challenge + +The current grammar uses regex patterns to match XML content: + +```lark +value_4_0: /[^<]*(<[^\/][^<]*)*?/ +``` + +This pattern: +- **Allows**: Regular text and non-closing angle brackets +- **Fails on**: Content that contains `<` followed by a `/` (closing tag) - **premature termination** +- **Fails on**: Content that contains `<` followed by a letter (opening tag) - **false positive tag detection** + +### Why This Is Fundamentally Broken + +1. **Look-Ahead Limitation**: Earley regex cannot express "match until you see `<‌/parameter>` but allow `<‌function=...>` in between" +2. **Finite Masks**: llguidance precomputes token masks, but nested XML requires unbounded context +3. **No Recursive Grammars**: Lark cannot express recursive XML structures in a way that maps to token masks + +### Current Workarounds + +#### Option A: Conservative Text Matching (Current) +```lark +value: /[^<]*(<[^\/][^<]*)*?/ +``` +- **Pros**: Works for most cases, finite mask possible +- **Cons**: Fails if parameter content contains `<` character + +#### Option B: Allow Any Character Until Strict End +```lark +value: /(?s).*?(?=<‌\/parameter>)/ +``` +- **Pros**: Handles `<` in content +- **Cons**: Requires look-ahead, impossible with finite masks + +#### Option C: Use Token IDs Instead of String Literals +```lark +value: /[^<]*(<[0-9]+[^\/][^<]*)*?/ +``` +- **Pros**: More flexible pattern matching +- **Cons**: Still can't handle nested `<` characters + +### The Real Problem + +``` +<‌parameter=new_string> ← Start of parameter +fn a() { let x = vec![1,2,3]; } ← Contains '<' characters +<‌/parameter> ← End of parameter (but mask sees '<' and thinks it's a tag) +``` + +When the mask encounters `<`, it: +1. Checks if next character is `/` → closing tag +2. Checks if next character is letter → opening tag +3. **Preempts content generation** before the actual `` + +### Recommended Solution: Avoid XML Parameters for Tool Calls + +Instead of nested XML like: + +```lark +<‌function=edit_file> +<‌parameter=file_path>/tmp/a.rs +<‌parameter=new_string>fn a() { let x = vec![1,2,3]; } +<‌/function> +``` + +Use **flat JSON** format: + +```json +{ + "name": "edit_file", + "arguments": { + "file_path": "/tmp/a.rs", + "new_string": "fn a() { let x = vec![1,2,3]; }" + } +} +``` + +### Grammar for JSON Tool Calls (Recommended) + +```lark +start: tool_call +tool_call: "<‌tool_call>" ws json_array ws "<‌/tool_call>" +json_array: "[" obj ("," obj)* "]" +obj: obj_search | obj_edit +obj_search: %json {"type":"object","properties":{...}} +obj_edit: %json {"type":"object","properties":{...}} +ws: /[ \t\r\n]+/ +``` + +This avoids the XML nested tag problem entirely by: +1. Using `%json` directives for structured parameter schemas +2. Not exposing parameter tags in the grammar +3. Letting the parser validate JSON structure instead of regex + +### Summary + +| Issue | Current State | Recommendation | +|-------|--------------|----------------| +| Nested XML tags | Cannot be expressed in finite mask grammar | Use JSON instead | +| `<` in parameter values | Causes premature termination | Avoid XML format | +| Look-ahead parsing | Not supported by llguidance lexer | Use simpler grammar structures | +| | + +Last updated: 2026-03-07 \ No newline at end of file diff --git a/example/special-tokens-extraction/Cargo.toml b/example/special-tokens-extraction/Cargo.toml new file mode 100644 index 00000000..b6945fcf --- /dev/null +++ b/example/special-tokens-extraction/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "special-tokens-extraction" +version = "0.1.0" +edition = "2024" + +[dependencies] +vllm-rs = { path = "../.." } diff --git a/example/special-tokens-extraction/README.md b/example/special-tokens-extraction/README.md new file mode 100644 index 00000000..2688520e --- /dev/null +++ b/example/special-tokens-extraction/README.md @@ -0,0 +1,195 @@ +# Special Tokens Extraction Tool + +This example tool helps developers analyze tokenizer special tokens visually using the `SpecialTokens` module from vllm.rs. + +## Purpose + +When integrating new models or tokenizer configurations, it's essential to understand: +- Which token IDs correspond to special tokens (EOS, BOS, PAD, TOOL, etc.) +- How the tokenizer encodes model-specific special tokens +- Whether custom token rules are needed for new model formats + +This tool extracts and displays all special tokens from a tokenizer file, making it easy to visualize and verify the token mapping. + +## Usage + +### Basic Usage + +```bash +# Using default tokenizer.json in current directory +cargo run --example special-tokens-extraction + +# Using a custom tokenizer path +cargo run --example special-tokens-extraction -- /path/to/tokenizer.json +``` + +### Example Output + +``` +=== Testing Tokenizer Library === + +Successfully loaded tokenizer from: tokenizer.json +Total added tokens processed. + +--- EOS Tokens --- +EOS: id=2 token= +EOS: id=128001 token=<|end_of_text|> +EOS IDs: [2, 128001] +EOS Strings: ["", "<|end_of_text|>"] + +--- PAD Tokens --- +PAD: id=0 token= + +--- BOS Tokens --- +BOS: id=1 token= + +--- TOOL Tokens --- +TOOL: id=151657 token= + +--- ROLE Tokens --- +ROLE: id=128007 token=_ROLE +ROLE: id=128008 token=ROLE_ + +--- MASK Tokens --- +MASK: id=32000 token= + +--- REASONING Tokens --- +REASONING: id=1 token= + +--- OTHER Tokens --- +OTHER: id=0 token= +OTHER: id=1 token= +OTHER: id=2 token= +OTHER: id=3 token= +``` + +## Token Categories + +The tool classifies tokens into the following categories: + +| Category | Description | Example Tokens | +|----------|-------------|----------------| +| `EOS` | End of sequence tokens | ``, `<|end_of_text|>`, `` | +| `PAD` | Padding tokens | ``, `` | +| `BOS` | Beginning of sequence tokens | ``, `<|start_of_turn|>` | +| `SEP` | Separator tokens | ``, `<|separator|>` | +| `CLS` | Classification tokens | ``, `[CLS]` | +| `MASK` | Mask tokens for masking | ``, `[MASK]` | +| `TOOL` | Tool-related tokens | ``, `<|tool|>` | +| `FUNCTION` | Function tokens | ``, `<|function|>` | +| `PARAMETER` | Parameter tokens | ``, `<|parameter|>` | +| `ROLE` | Role tokens (chat templates) | `ROLE`, `ROLE_`, `<|role|>` | +| `CONTENT_TYPE` | Content type tokens | ``, `<|content_type|>` | +| `REASONING` | Reasoning/thinking tokens | ``, ``, `` | +| `OTHER` | Unmatched tokens | ``, etc. | + +## Understanding SpecialTokens Rules + +The `SpecialTokens` struct uses a flexible matching system based on `MatchRule`: + +### MatchRule Types + +```rust +pub enum MatchRule { + Exact(String), // Exact match: "" matches "" + StartsWith(String), // Prefix match: "<|end" matches "<|end_of_text|>" + Contains(String), // Substring match: "tool" matches "" + And(Box, Box), // Both rules must match + Or(Box, Box), // Either rule must match + Not(Box), // Rule must NOT match +} +``` + +### Default Rules + +The `default_rules()` function in `src/utils/special_tokens.rs` defines matching rules for all categories. + +## Customizing Token Rules + +To add support for new token patterns: + +1. Edit `src/utils/special_tokens.rs` +2. Add new rules to `default_rules()`: + +```rust +// Example: Add custom thinking token +(MatchRule::Contains("custom_thinking".to_string()), Category::Reasoning), + +// Example: Custom tool start token +(MatchRule::Exact("".to_string()), Category::Tool), +``` + +3. Test with the extraction tool: +```bash +cargo run --example special-tokens-extraction -- /path/to/tokenizer.json +``` + +## Integration with vllm.rs + +The `SpecialTokens` struct is used throughout vllm.rs: + +### Engine Initialization + +```rust +// src/core/engine.rs:474 +let special_tokens = Arc::new(SpecialTokens::new(&tokenizer)); +``` + +### Scheduler Usage + +```rust +// src/core/scheduler.rs:135 +eos_token_id: special_tokens.eos_ids(), +``` + +### Guidance/LLG Usage + +```rust +// src/utils/guidance.rs:485 +pub fn chat_text_expression_with_eos(special_tokens: &SpecialTokens) -> String { + let eos_token_ids = special_tokens.eos_ids(); + // ... build TEXT pattern with EOS tokens +} +``` + +## Troubleshooting + +### Common Issues + +1. **Empty token list** + - Check that the tokenizer file path is correct + - Verify the tokenizer has added tokens (some tokenizers use vocab tokens) + +2. **Tokens not classified correctly** + - Add custom rules in `src/utils/special_tokens.rs` + - Use `search(None, Some("substring"))` to debug token matching + +3. **Token ID collisions** + - The `SpecialTokens::new()` implementation deduplicates by token ID + - Check with `search(Some(token_id), None)` to verify uniqueness + +### Debugging with the Search API + +```rust +// Search by ID +let matches = special_tokens.search(Some(151657), None); +for m in matches { + println!("ID 151657: {} -> {}", m.category, m.content); +} + +// Search by substring +let matches = special_tokens.search(None, Some("tool")); +for m in matches { + println!("Contains 'tool': {} -> {}", m.category, m.content); +} +``` + +## File Reference + +- **Source**: `src/utils/special_tokens.rs` +- **Example**: `example/special-tokens-extraction/src/main.rs` +- **Tests**: `src/utils/special_tokens.rs` (test module at end of file) + +## License + +This example is part of the vllm.rs project. diff --git a/example/special-tokens-extraction/src/main.rs b/example/special-tokens-extraction/src/main.rs new file mode 100644 index 00000000..63f3a35e --- /dev/null +++ b/example/special-tokens-extraction/src/main.rs @@ -0,0 +1,169 @@ +use vllm_rs::utils::special_tokens::SpecialTokens; +use std::env; + +fn main() { + println!("=== Testing Tokenizer Library ===\n"); + + // Path to our mock tokenizer file + let args: Vec = env::args().collect(); + let tokenizer_path = if args.len() > 1 { + args[1].clone() + } else { + "./tokenizer.json".to_string() + }; + + let special = SpecialTokens::new_from_file(&tokenizer_path); + + let reasoning_matches = special.search(None, Some("tool"), None, None); + for m in reasoning_matches { + println!("Search Result - Category: {:?}, ID: {}, string='{}'", m.category, m.id, m.string()); + // Also show the hex representation + let hex: Vec = m.content.iter().map(|b| format!("0x{:02x}", b)).collect(); + println!(" Hex: {:?}", hex); + } + + println!("\nSuccessfully loaded tokenizer from: {}", tokenizer_path); + println!("Total tokens processed: {}\n", special.all_tokens().len()); + + // Test Eos + println!("--- EOS Tokens ---"); + for token in special.eos() { + println!("EOS: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!("EOS IDs: {:?}", special.eos_ids()); + println!("EOS Strings: {:?}", special.eos_strings()); + println!(); + + // Test Pad + println!("--- PAD Tokens ---"); + for token in special.pad() { + println!("PAD: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Bos + println!("--- BOS Tokens ---"); + for token in special.bos() { + println!("BOS: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Tool + println!("--- TOOL Tokens ---"); + for token in special.tool() { + println!("TOOL: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Role + println!("--- ROLE Tokens ---"); + for token in special.role() { + println!("ROLE: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Mask + println!("--- MASK Tokens ---"); + for token in special.mask() { + println!("MASK: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Reasoning + println!("--- REASONING Tokens ---"); + for token in special.reasoning() { + println!("REASONING: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Other (Tokens that didn't match specific rules above, e.g., ) + println!("--- OTHER Tokens ---"); + for token in special.other() { + println!("OTHER: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test content_type + println!("--- CONTENT_TYPE Tokens ---"); + for token in special.content_type() { + println!("CONTENT_TYPE: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Function + println!("--- FUNCTION Tokens ---"); + for token in special.function() { + println!("FUNCTION: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Parameter + println!("--- PARAMETER Tokens ---"); + for token in special.parameter() { + println!("PARAMETER: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Sep + println!("--- SEP Tokens ---"); + for token in special.sep() { + println!("SEP: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test Cls + println!("--- CLS Tokens ---"); + for token in special.cls() { + println!("CLS: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + println!(); + + // Test tool start/end helpers + println!("--- Tool Start IDs ---"); + println!("{:?}", special.tool_start_ids()); + println!("--- Tool End IDs ---"); + println!("{:?}", special.tool_end_ids()); + + // Additional search examples + println!("\n=== Additional Search Examples ===\n"); + + // Search by category only + println!("--- Search all EOS tokens ---"); + let eos_results = special.search(None, None, Some(vllm_rs::utils::special_tokens::Category::Eos), None); + for token in eos_results { + println!(" EOS: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + + // Search by source (Added tokens) + println!("\n--- Search Added tokens ---"); + let added_results = special.search(None, None, None, Some(vllm_rs::utils::special_tokens::VocabSource::Added)); + for token in added_results { + println!(" Added: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + + // Search by ID + println!("\n--- Search by ID 2 ---"); + let id_results = special.search(Some(2), None, None, None); + for token in id_results { + println!(" ID 2: category={:?} source={:?} string={}", token.category, token.source, token.string()); + } + + // Search by substring and category combined + println!("\n--- Search tokens containing 'end' in EOS category ---"); + let combined_results = special.search(None, Some("end"), Some(vllm_rs::utils::special_tokens::Category::Eos), None); + for token in combined_results { + println!(" Token: id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + + // Get all special tokens (Special or Added source) + println!("\n--- All Special/Added tokens ---"); + for token in special.all_special() { + println!(" id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } + + // Print all tokens with full details + println!("\n=== All Tokens (Full Details) ==="); + for token in special.all_tokens() { + println!("id={} category={:?} source={:?} string={}", token.id, token.category, token.source, token.string()); + } +} diff --git a/src/api.rs b/src/api.rs index 83b3bbf4..26870fd0 100644 --- a/src/api.rs +++ b/src/api.rs @@ -151,6 +151,8 @@ impl EngineBuilder { None, self.pd_server_prefix_cache_ratio, self.pd_client_prefix_cache_ratio, + false, // allow_constraint_api + false, // enable_tool_grammar ); let dtype = self.dtype.clone().map(dtype_to_str); diff --git a/src/core/block_manager.rs b/src/core/block_manager.rs index 1f9bfefa..dcafc49e 100644 --- a/src/core/block_manager.rs +++ b/src/core/block_manager.rs @@ -571,6 +571,11 @@ impl BlockManager { .map_or(0, |cache| cache.cached_blocks()) } + /// Get a reference to the runners Arc + pub fn get_runners(&self) -> &Arc> { + &self.runners + } + /// Returns how many tokens of `seq` are already cached in the prefix cache. /// Used to decide whether to do local prefill vs transfer to PD server. pub fn get_prefix_cache_match_tokens(&mut self, seq: &Sequence) -> usize { @@ -689,6 +694,11 @@ impl BlockManager { self.block_size } + /// Get the block size + pub fn block_size(&self) -> usize { + self.block_size + } + pub fn get_cpu_swap_usage(&self) -> f32 { let total_cpu_blocks = self.cpu_blocks.len(); (total_cpu_blocks - self.free_cpu_block_ids.len()) as f32 / total_cpu_blocks as f32 @@ -951,4 +961,67 @@ impl BlockManager { } } } + + /// Rollback a sequence to a specific token position, releasing blocks beyond that point. + /// This is used for speculative decoding mismatch recovery. + pub fn rollback_to_seq_tokens(&mut self, seq: &mut Sequence, target_tokens: usize) -> Result<()> { + let current_tokens = seq.len(); + if target_tokens >= current_tokens { + return Ok(()); // Nothing to rollback + } + + // Calculate how many blocks to release + let target_blocks = target_tokens.div_ceil(self.block_size); + let blocks_to_release = current_tokens.div_ceil(self.block_size) - target_blocks; + + if blocks_to_release > 0 { + // Release blocks from the end + let released: Vec = seq.block_table.drain(target_blocks..).collect(); + for &block_id in &released { + let block_id_usize = block_id as usize; + self.decrement_block_ref(block_id_usize); + } + } + + // Update cached token count + let target_full_blocks = target_tokens / self.block_size; + seq.num_cached_tokens = target_full_blocks * self.block_size; + + // Update prefix cache if enabled + if self.prefix_cache.is_some() { + // Extract prefix_cache to avoid borrow conflicts + let mut prefix_cache = self.prefix_cache.take().unwrap(); + + // Calculate which hashes correspond to released blocks + let target_full_blocks = target_tokens / self.block_size; + let current_full_blocks = current_tokens / self.block_size; + + // Collect hashes to remove + let mut hashes_to_remove = Vec::new(); + + for block_idx in target_full_blocks..current_full_blocks { + // Get the block_id for this position before release + if let Some(&block_id_u32) = seq.block_table.get(block_idx) { + let block_id = block_id_u32 as usize; + + // Find the hash associated with this block_id + if let Some(hash) = prefix_cache.hash_for_block(block_id) { + hashes_to_remove.push(hash); + } + } + } + + // Remove hashes from prefix cache and mamba mappings + for hash in hashes_to_remove { + if prefix_cache.remove_hash(&hash).is_some() { + self.invalidate_mamba_prefix_hash(hash); + } + } + + // Put prefix_cache back + self.prefix_cache = Some(prefix_cache); + } + + Ok(()) + } } diff --git a/src/core/engine.rs b/src/core/engine.rs index 0812b5a3..585bad9c 100644 --- a/src/core/engine.rs +++ b/src/core/engine.rs @@ -19,8 +19,9 @@ use crate::tools::Tool; use crate::transfer::PdRole; use crate::transfer::Transfer; use crate::utils::chat_template::Message; -use crate::utils::config::{EngineConfig, EosTokenId, ModelType, SamplingParams}; -use crate::utils::guidance::load_toktrie_from_path; +use crate::utils::config::{EngineConfig, ModelType, SamplingParams}; +use crate::utils::special_tokens::SpecialTokens; +use crate::utils::guidance::{build_llg_factory, load_toktrie_from_path}; use crate::utils::heartbeat::heartbeat_worker; use crate::utils::image::{get_image_config, ImageData, ImageProcessConfig}; use crate::utils::kvcache_allocator::KVCacheAllocator; @@ -100,6 +101,9 @@ pub struct LLMEngine { pub model_type: ModelType, pub tool_config: ToolConfig, pub img_cfg: Option, + /// SpecialTokens parsed once at engine initialization + /// Contains EOS, BOS, and other special token IDs and their string representations + pub special_tokens: Arc, } impl LLMEngine { @@ -107,9 +111,23 @@ impl LLMEngine { pub fn new(econfig: &EngineConfig, dtype: DType) -> Result>> { let (model_pathes, is_gguf, mut config, config_tokenizer, tokenizer, mut generation_cfg) = init_config_tokenizer(econfig)?; - let toktrie = load_toktrie_from_path(&model_pathes.get_tokenizer_filename()).map(Arc::new); + let toktrie = match load_toktrie_from_path(&model_pathes.get_tokenizer_filename()) { + Ok(trie) => Some(Arc::new(trie)), + Err(e) => { + crate::log_warn!("Failed to load tokenizer trie: {}", e); + None + } + }; + let llg_factory = match build_llg_factory(tokenizer.clone(), config.vocab_size) { + Ok(f) => Some(f), + Err(e) => { + crate::log_warn!("Failed to build llguidance factory: {}", e); + None + } + }; + if toktrie.is_none() { - crate::log_warn!("Guided decoding disabled: tokenizer trie unavailable."); + crate::log_warn!("Guided decoding (legacy) disabled: tokenizer trie unavailable."); } let stop_flag = Arc::new(AtomicBool::new(false)); @@ -118,13 +136,16 @@ impl LLMEngine { prepare_engine_config(econfig, &config, &config_tokenizer, &mut generation_cfg); config.fp8_kvcache = econfig.fp8_kvcache; - // In case config file missing bos and eos configuratioin + // Initialize SpecialTokens early to use for EOS token extraction + let special_tokens = SpecialTokens::new(&tokenizer); + + // In case config file missing bos and eos configuration config.apply_generation_cfg(generation_cfg.as_ref()); if config.eos_token_id.is_none() { - if let Some(eos) = &config_tokenizer.eos_token { - if let Some(token) = tokenizer.get_vocab(true).get(eos).copied() { - config.eos_token_id = Some(EosTokenId::Single(token)); - }; + // Extract EOS tokens from SpecialTokens (single source of truth) + let eos_ids: Vec = special_tokens.eos_ids(); + if !eos_ids.is_empty() { + config.eos_token_id = Some(eos_ids); } } assert!( @@ -175,43 +196,48 @@ impl LLMEngine { let reporter: Arc>> = Arc::new(RwLock::new(Box::new(ProgressReporter::new(0)))); let handle = progress_worker(1, config.num_hidden_layers, &reporter); - let vb = VarBuilderX::new(&model_pathes, is_gguf, dtype, &device)?; - let transfer = if let Some(p_cfg) = &econfig.pd_config { - Some(Arc::new(Transfer::new( - p_cfg.clone(), - 0, - model_loaded.clone(), - stop_flag.clone(), - )?)) - } else { - None - }; - - let mut model_runner = ModelRunner::new( - model_type.clone(), - &vb, - #[cfg(not(feature = "nccl"))] - Rc::new(Comm::default()), - #[cfg(feature = "nccl")] - Rc::new( - Comm::from_rank( - device.as_cuda_device().unwrap().cuda_device(), + let mut model_runner = { + let _guard = candle_core::InferenceMode::enter(); + let vb = VarBuilderX::new(&model_pathes, is_gguf, dtype, &device)?; + let transfer = if let Some(p_cfg) = &econfig.pd_config { + Some(Arc::new(Transfer::new( + p_cfg.clone(), 0, - 1, - Id::new().unwrap(), - ) - .unwrap(), - ), - &mut econfig, - &config, - dtype, - is_rope_i, - device.clone(), - reporter, - transfer, - toktrie.clone(), - None, - )?; + model_loaded.clone(), + stop_flag.clone(), + )?)) + } else { + None + }; + + let runner = ModelRunner::new( + model_type.clone(), + &vb, + #[cfg(not(feature = "nccl"))] + Rc::new(Comm::default()), + #[cfg(feature = "nccl")] + Rc::new( + Comm::from_rank( + device.as_cuda_device().unwrap().cuda_device(), + 0, + 1, + Id::new().unwrap(), + ) + .unwrap(), + ), + &mut econfig, + &config, + dtype, + is_rope_i, + device.clone(), + reporter, + transfer, + llg_factory.clone(), + None, + )?; + drop(vb); + runner + }; if !is_pd_server { //No graph capture for PD server @@ -376,29 +402,29 @@ impl LLMEngine { econfig.max_model_len = Some(32768); } let runners = Arc::new(RwLock::new(runners)); - let mut scheduler = Scheduler::new(runners.clone(), &econfig, &config); - // Initialize tool call end tokens for detection based on model type. - let mut tool_config = ToolConfig::for_model_type(&model_type); - tool_config.validate_with_tokenizer(&tokenizer, &model_type); - let tool_call_start_ids = tool_config.tool_call_start_ids(&tokenizer); - let tool_call_end_ids = tool_config.tool_call_end_ids(&tokenizer); + let special_tokens = Arc::new(special_tokens); + let mut scheduler = Scheduler::new(runners.clone(), &econfig, &config, special_tokens.clone()); + + // Initialize tool call end tokens using SpecialTokens for idiomatic access + let tool_call_start_ids: Vec = special_tokens.tool_start_ids(); + let tool_call_end_ids: Vec = special_tokens.tool_end_ids(); if !tool_call_start_ids.is_empty() { scheduler.set_tool_call_start_tokens(tool_call_start_ids.clone()); log_info!( - "Tool call start token IDs set to: {:?}", + "Tool call start token IDs set from SpecialTokens: {:?}", tool_call_start_ids ); } else { - log_info!("Tool call start token IDs not set (no reliable start token)"); + log_info!("Tool call start token IDs not set (no tool start tokens found in tokenizer)"); } if !tool_call_end_ids.is_empty() { scheduler.set_tool_call_end_tokens(tool_call_end_ids.clone()); - log_info!("Tool call end token IDs set to: {:?}", tool_call_end_ids); + log_info!("Tool call end token IDs set from SpecialTokens: {:?}", tool_call_end_ids); } else { - log_info!("Tool call end token IDs not set (no reliable end token)"); + log_info!("Tool call end token IDs not set (no tool end tokens found in tokenizer)"); } // Set tokenizer for JSON tool call detection (for models like Qwen3 that output raw JSON) @@ -424,7 +450,7 @@ impl LLMEngine { ); let escaped_special_tokens = ChatTemplate::collect_escape_tokens( &tokenizer, - &[&tool_config.start_token_str, &tool_config.end_token_str], + &[], ); template.set_escape_tokens(escaped_special_tokens); @@ -439,6 +465,39 @@ impl LLMEngine { "default".to_string() }; + // Create tool config from special tokens for backward compatibility + // Use idiomatic tool_tokens() method that returns Option<(SpecialToken, SpecialToken)> + let tool_config = match special_tokens.tool_tokens() { + Some((start_token, end_token)) => { + log_info!( + "Tool tokens extracted from SpecialTokens: start={:?} ({:?}), end={:?} ({:?})", + start_token.id, + start_token.string(), + end_token.id, + end_token.string() + ); + ToolConfig { + start_token_ids: HashSet::from_iter(vec![start_token.id]), + end_token_ids: HashSet::from_iter(vec![end_token.id]), + start_token_str: start_token.string(), + end_token_str: end_token.string(), + start_is_special: false, + end_is_special: false, + } + } + None => { + log_info!("Tool tokens not found in SpecialTokens, falling back to empty config"); + ToolConfig { + start_token_ids: HashSet::new(), + end_token_ids: HashSet::new(), + start_token_str: "".to_string(), + end_token_str: "".to_string(), + start_is_special: false, + end_is_special: false, + } + } + }; + let engine = Arc::new(RwLock::new(Self { runners, scheduler, @@ -461,6 +520,7 @@ impl LLMEngine { tool_config, img_cfg, model_name, + special_tokens, })); Self::start_engine(engine.clone()); Ok(engine) @@ -517,31 +577,14 @@ impl LLMEngine { } if let Some(stop_sequences) = ¶ms.stop_sequences { - let mut stop_token_ids = Vec::new(); let mut resolved_stop_sequences = Vec::new(); for sequence in stop_sequences { if sequence.is_empty() { continue; } - match self.tokenizer.encode(sequence.as_str(), false) { - Ok(encoding) => { - let ids = encoding.get_ids(); - if !ids.is_empty() { - stop_token_ids.push(ids.to_vec()); - resolved_stop_sequences.push(sequence.clone()); - } - } - Err(err) => { - crate::log_warn!( - "Failed to encode stop sequence '{}': {:?}", - sequence, - err - ); - } - } + resolved_stop_sequences.push(sequence.clone()); } - if !stop_token_ids.is_empty() { - params.stop_token_ids = Some(stop_token_ids); + if !resolved_stop_sequences.is_empty() { params.stop_sequences = Some(resolved_stop_sequences); } } @@ -1576,4 +1619,9 @@ impl LLMEngine { pub fn get_chat_template(&self) -> ChatTemplate { self.template.clone() } + + pub fn template_supports_tools(&self) -> bool { + self.template.supports_tools() + } + } diff --git a/src/core/mod.rs b/src/core/mod.rs index 95563506..d4abf4ef 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -78,6 +78,42 @@ macro_rules! log_warn { }; } +#[macro_export] +macro_rules! log_debug { + ($($arg:tt)*) => { + { + #[cfg(feature = "python")] + { + use colored::Colorize; + let s = format!($($arg)*); + println!("{}", String::from(s).truecolor(100, 100, 100)); + } + #[cfg(not(feature = "python"))] + { + tracing::debug!($($arg)*); + } + } + }; +} + +#[macro_export] +macro_rules! log_trace { + ($($arg:tt)*) => { + { + #[cfg(feature = "python")] + { + use colored::Colorize; + let s = format!($($arg)*); + println!("{}", String::from(s).truecolor(50, 50, 50)); + } + #[cfg(not(feature = "python"))] + { + tracing::trace!($($arg)*); + } + } + }; +} + #[macro_export] macro_rules! log_error { ($($arg:tt)*) => { diff --git a/src/core/prefix_cache.rs b/src/core/prefix_cache.rs index ea9c41b5..323ef271 100644 --- a/src/core/prefix_cache.rs +++ b/src/core/prefix_cache.rs @@ -323,6 +323,36 @@ impl PrefixCache { self.access_counter } + /// Remove a hash from the cache and update parent/children bookkeeping + /// Returns the removed block_id if found + pub fn remove_hash(&mut self, hash: &u64) -> Option { + let entry = self.entries.remove(hash)?; + let block_id = entry.block_id; + + // Update parent's children count + if let Some(parent_hash) = entry.parent { + if let Some(parent_entry) = self.entries.get_mut(&parent_hash) { + parent_entry.children -= 1; + if parent_entry.children == 0 { + self.leaf_set.insert(parent_hash); + self.touch_leaf(parent_hash); + } + } + } + + // Remove from leaf set + self.leaf_set.remove(&hash); + + Some(block_id) + } + + /// Find the hash associated with a block_id + pub fn hash_for_block(&self, block_id: usize) -> Option { + self.entries.iter() + .find(|(_, entry)| entry.block_id == block_id) + .map(|(hash, _)| *hash) + } + fn hash_block(parent_hash: u64, tokens: &[u32]) -> u64 { let mut hasher = std::collections::hash_map::DefaultHasher::new(); parent_hash.hash(&mut hasher); diff --git a/src/core/runner.rs b/src/core/runner.rs index 2c260121..92f41c4e 100644 --- a/src/core/runner.rs +++ b/src/core/runner.rs @@ -8,7 +8,9 @@ use crate::transfer::Transfer; use crate::utils::graph::{ planned_graph_capture_batches, CudaGraphFn, CudaGraphWrapper, GraphCapturer, ModelFn, }; -use crate::utils::guidance::GuidanceState; +use crate::utils::guidance::{GuidanceState, ParserFactory}; +// use crate::utils::guidance::{GuidanceState, ParserFactory, batch_mask_bias, early_exit_validate}; +use toktrie::SimpleVob; use crate::utils::image::compute_image_slice; use crate::utils::logits_processor::{LogitsProcessor, Sampling}; use crate::utils::progress::ProgressLike; @@ -35,10 +37,9 @@ use attention_rs::InputMetadata; use candle_core::{DType, Device, Result, Tensor, D}; use interprocess::local_socket::Stream as LocalStream; use parking_lot::RwLock; -use std::collections::{HashMap, HashSet}; +use std::collections::{hash_map::Entry, HashMap, HashSet}; use std::rc::Rc; use std::sync::{Arc, Mutex, MutexGuard}; -use toktrie::TokTrie; /// Cached sampling parameters computed once during prefill, reused during decode #[derive(Clone, Debug)] @@ -94,6 +95,9 @@ pub struct ModelRunner { seq_tokens: RwLock>>, restored_prefix_sequences: RwLock>, guidance_states: RwLock>, + guidance_failed: RwLock>, + guidance_mismatch: RwLock>, + llg_factory: Option>, transfer: Option>, /// Whether this runner is on the first rank (for logging) is_first_rank: bool, @@ -163,7 +167,7 @@ impl ModelRunner { device: Device, reporter: Arc>>, transfer: Option>, - toktrie: Option>, + llg_factory: Option>, stream: Option, ) -> Result { let model = crate::build_model!( @@ -417,12 +421,15 @@ impl ModelRunner { cached_sampling: RwLock::new(None), seq_tokens: RwLock::new(HashMap::new()), restored_prefix_sequences: RwLock::new(HashSet::new()), - guidance_states: RwLock::new(HashMap::new()), - transfer, - is_first_rank: comm.rank() == 0, - model_type, - }) - } + guidance_states: RwLock::new(HashMap::new()), + guidance_failed: RwLock::new(HashSet::new()), + guidance_mismatch: RwLock::new(HashSet::new()), + llg_factory, + transfer, + is_first_rank: comm.rank() == 0, + model_type, + }) + } pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<(Tensor, Tensor)>> { loop { @@ -1215,10 +1222,205 @@ impl ModelRunner { logits.to_owned() }; - let tokens = self + let logits = if let Some(factory) = &self.llg_factory { + let mut guidance_states = self.guidance_states.write(); + let mut guidance_failed = self.guidance_failed.write(); + let mut guidance_mismatch = self.guidance_mismatch.write(); + let mut modified = false; + let vocab_size = logits.dim(1)?; + + // We only materialize logits on CPU if at least one constraint mask applies. + + // We'll collect masks first to minimize holding locks or complex logic inside the loop + let mut masks: Vec<(usize, usize, SimpleVob)> = Vec::new(); // (seq_index, seq_id, mask) + + for (i, id) in seq_ids.iter().enumerate() { + let sampling_params = match &seqs { + Seqs::SeqRefs(refs) => &refs[i].sampling_params, + Seqs::DecodeVec(vec) => &vec[i].sampling_params, + }; + + if guidance_failed.contains(id) { + continue; + } + + // Use grammar directly from sampling_params + let grammar = match sampling_params.grammar.as_ref() { + Some(g) => g, + None => continue, + }; + + let state = match guidance_states.entry(*id) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + match GuidanceState::new_from_grammar(factory.clone(), grammar) { + Ok(state) => entry.insert(state), + Err(err) => { + guidance_failed.insert(*id); + crate::log_warn!( + "[Seq {}] Failed to create guidance state: {}. Disabling constraints for this sequence.", + id, + err + ); + continue; + } + } + } + }; + + if let Ok(Some(mask)) = state.compute_mask() { + masks.push((i, *id, mask)); + modified = true; + } + } + + if modified { + // Now we must convert to Vec, modify, and update logits + let mut logits_vec = logits.flatten_all()?.to_vec1::()?; + + for (seq_idx, seq_id, mask) in masks { + let start = seq_idx * vocab_size; + let end = start + vocab_size; + let row = &mut logits_vec[start..end]; + let mask_len = mask.len(); + + // Apply mask: set disallowed to -inf + // This iterates entire vocab, but check is fast + if mask_len == 0 { + if guidance_failed.insert(seq_id) { + crate::log_warn!( + "[Seq {}] Guidance mask length is 0. Disabling constraints for this sequence.", + seq_id + ); + } + continue; + } + + if mask_len != vocab_size && guidance_mismatch.insert(seq_id) { + crate::log_warn!( + "[Seq {}] Guidance mask size {} does not match vocab size {}. Clamping mask application.", + seq_id, + mask_len, + vocab_size + ); + // Snapshot is captured when constraint is first applied in GuidanceState::new() + // Rollback is handled via Matcher::rollback() in GuidanceState::rollback_to() + } + + let apply_len = std::cmp::min(vocab_size, mask_len); + for tok in 0..apply_len { + if !mask.is_allowed(tok as u32) { + row[tok] = f32::NEG_INFINITY; + } + } + if mask_len < vocab_size { + for tok in mask_len..vocab_size { + row[tok] = f32::NEG_INFINITY; + } + } + } + Tensor::from_vec(logits_vec, logits.shape(), &self.device)? + /* + // Use optimized batch mask bias function + batch_mask_bias( + &logits, + &masks.iter().map(|(seq_idx, _, mask)| (*seq_idx, mask.clone())).collect::>(), + vocab_size, + )? + */ + } else { + logits + } + + } else { + logits + }; + + let mut tokens = self .logit_processor .sample_with_strategy(&logits, &cached_params.sampling)?; + // Re-sample tokens that fail validation (hybrid approach) + if let Some(_factory) = &self.llg_factory { + let mut guidance_states = self.guidance_states.write(); + for (seq_idx, seq_id) in seq_ids.iter().enumerate() { + let token = tokens[seq_idx]; + + crate::log_trace!("[llg] Processing seq {} (idx {}): token {}", seq_id, seq_idx, token); + + if let Some(state) = guidance_states.get_mut(seq_id) { + if state.is_finished() { + crate::log_trace!("[llg] Matcher is stopped for seq {}, skipping validation", seq_id); + continue; + } + + let valid = state.validate_token(token); + crate::log_trace!("[llg] Token {} validation result: {}", token, valid); + + if valid { + crate::log_trace!("[llg] Token {} is valid, consuming for seq {}", token, seq_id); + let _ = state.commit_token(token); + } else { + crate::log_debug!("[llg] Token {} is invalid, computing mask for seq {}", token, seq_id); + let mask = match state.compute_mask_or_eos() { + Ok(m) => m, + Err(e) => { + crate::log_error!( + "[llg] Unable to compute mask for token {} due to {}", token, e + ); + continue; + } + }; + + crate::log_debug!("[llg] Applying bias to logits for seq {}", seq_id); + + // Memory-efficient: use flat vector with slice operations + let vocab_size = logits.dim(1)?; + let row_start = seq_idx * vocab_size; + let row_end = row_start + vocab_size; + + let mut row_vec = logits.clone().flatten_all()?.to_vec1::()?; + let row = &mut row_vec[row_start..row_end]; + + // Direct mask application: set disallowed tokens to -inf + for tok in 0..vocab_size { + if !mask.is_allowed(tok as u32) { + row[tok] = f32::NEG_INFINITY; + } + } + + // Create tensor with correct shape for re-sampling + let biased_tensor = Tensor::from_vec(row_vec, logits.shape(), logits.device())?; + + crate::log_debug!("[llg] Re-sampling with biased logits for seq {}", seq_id); + + // Use sample_with_strategy with proper cached params + let re_sampled = self.logit_processor.sample_with_strategy(&biased_tensor, &cached_params.sampling)?; + tokens[seq_idx] = re_sampled[seq_idx]; + + crate::log_debug!("[llg] Consuming re-sampled token {} for seq {}", tokens[seq_idx], seq_id); + let _ = state.commit_token(tokens[seq_idx]); + } + } else { + crate::log_debug!("[llg] No guidance state for seq {}", seq_id); + } + /* + // Use optimized early exit validation + let vocab_size = logits.dim(1)?; + early_exit_validate( + &mut guidance_states, + &seq_ids, + &mut tokens, + &logits, + vocab_size, + factory, + &cached_params.sampling, + &self.logit_processor, + )? + */ + } + } + // Track tokens for sequences when penalties are enabled if has_any_penalty { let mut seq_tokens = self.seq_tokens.write(); @@ -1233,6 +1435,8 @@ impl ModelRunner { } } } + + // Token commits are now done inline in the re-sample loop below Ok(tokens) } @@ -1411,19 +1615,55 @@ impl ModelRunner { pub fn clear_blocks(&self, _block_ids: Vec) -> Result { Ok(true) - // fn cache_clear(gpu_cache: &Vec<(Tensor, Tensor)>, block_ids: &Vec) -> Result { - // if gpu_cache.is_empty() || block_ids.is_empty() { - // return Ok(true); - // } + } - // for i in 0..gpu_cache.len() { - // cache::clear_blocks(&gpu_cache[i].0, block_ids)?; - // cache::clear_blocks(&gpu_cache[i].1, block_ids)?; - // } + /// Validate a sequence's output_ids against the grammar using llguidance + /// Returns Some(valid_token_count) if guidance exists, None if no constraint + pub fn validate_sequence_for_grammar(&self, seq_id: usize, output_ids: &[u32]) -> Option { + let mut guidance_states = self.guidance_states.write(); + let state = guidance_states.get_mut(&seq_id)?; + match state.validate_tokens(output_ids) { + Some(count) => Some(count), + None => None, + } + } - // Ok(true) - // } + /// Rollback guidance state for a sequence + /// This is called from Scheduler::rollback_sequence() to reset llguidance FSM state + pub fn rollback_sequence_for_guidance(&self, seq_id: usize, target_tokens: usize) -> Result<()> { + let mut guidance_states = self.guidance_states.write(); + let mut guidance_failed = self.guidance_failed.write(); + let mut guidance_mismatch = self.guidance_mismatch.write(); + + if let Some(state) = guidance_states.get_mut(&seq_id) { + // Calculate byte position (approx 4 bytes per token) + let target_bytes = target_tokens * 4; + match state.rollback_to(target_tokens, target_bytes) { + Ok(()) => {} + Err(e) => { + return Err(candle_core::Error::Msg(format!("Guidance rollback failed: {}", e))); + } + } + } + + // Clear failed and mismatch status for re-initialization + guidance_failed.remove(&seq_id); + guidance_mismatch.remove(&seq_id); + + Ok(()) + } - // cache_clear(&*self.get_kv_cache(), &block_ids) + /// Fast-forward and consume tokens guaranteed to be accepted by the grammar + /// This is used for speculative decoding optimization + pub fn consume_ff_tokens(&self, seq_id: usize) -> Result> { + let mut guidance_states = self.guidance_states.write(); + if let Some(state) = guidance_states.get_mut(&seq_id) { + match state.consume_ff_tokens() { + Ok(tokens) => Ok(tokens), + Err(e) => Err(candle_core::Error::Msg(format!("FF tokens failed: {}", e))), + } + } else { + Ok(Vec::new()) + } } } diff --git a/src/core/scheduler.rs b/src/core/scheduler.rs index 9c786d4a..62d7e5af 100644 --- a/src/core/scheduler.rs +++ b/src/core/scheduler.rs @@ -5,11 +5,12 @@ use super::{ prefix_cache::PrefixCacheConfig, sequence::{Sequence, SequenceStatus}, }; +use crate::tools::parser::prefix_could_be_tool; use crate::transfer::{PdConfig, PdRole}; -use crate::utils::config::{Config, EngineConfig, EosTokenId}; +use crate::utils::config::{Config, EngineConfig}; +use crate::utils::special_tokens::SpecialTokens; use candle_core::Result; use parking_lot::RwLock; -use regex::Regex; use std::collections::VecDeque; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; @@ -21,17 +22,16 @@ pub struct Scheduler { transferred: VecDeque, pub block_manager: BlockManager, next_seq_id: usize, - eos_token_id: Vec, + /// Token IDs that represent the end of sequence + pub eos_token_id: Vec, /// Token IDs that represent the end of a tool call (e.g., tokens) tool_call_end_token_ids: Vec, /// Token IDs that represent the start of a tool call (used to avoid false end matches) tool_call_start_token_ids: Vec, /// Token ID for } character (used for JSON tool call detection) - json_end_token_id: Option, + json_end_token_ids: Vec, /// Tokenizer for decoding output to check JSON tool call patterns tokenizer: Option>, - /// Regex for detecting JSON tool calls - tool_call_regex: Regex, cfg: EngineConfig, pd_config: Option, is_last_prefill: bool, @@ -112,7 +112,13 @@ fn build_prefix_cache_config(econfig: &EngineConfig) -> PrefixCacheConfig { } impl Scheduler { - pub fn new(runners: Arc>, econfig: &EngineConfig, config: &Config) -> Self { + /// Create a new Scheduler with SpecialTokens for EOS detection + pub fn new( + runners: Arc>, + econfig: &EngineConfig, + config: &Config, + special_tokens: Arc, + ) -> Self { let prefix_cache_cfg = build_prefix_cache_config(econfig); Self { waiting: VecDeque::new(), @@ -133,19 +139,13 @@ impl Scheduler { .unwrap_or(false), ), next_seq_id: 0, - eos_token_id: match &config.eos_token_id { - Some(EosTokenId::Single(eos)) => vec![*eos], - Some(EosTokenId::Multiple(eos)) => eos.into_iter().map(|x| *x).collect(), - _ => vec![], - }, + // Use SpecialTokens for EOS token IDs - this is the single source of truth + eos_token_id: special_tokens.eos_ids(), // Tool call end tokens will be set by engine after tokenizer is initialized tool_call_end_token_ids: Vec::new(), tool_call_start_token_ids: Vec::new(), - json_end_token_id: None, + json_end_token_ids: Vec::new(), tokenizer: None, - // Regex to match JSON tool call format: {"name": "...", "arguments": {...}} - // We use (?s) to allow dot matching newlines - tool_call_regex: Regex::new(r#"(?s)\{\s*"name"\s*:.*"arguments"\s*:.*\}\s*$"#).unwrap(), cfg: econfig.clone(), pd_config: econfig.pd_config.clone(), is_last_prefill: false, @@ -164,13 +164,21 @@ impl Scheduler { /// Set tokenizer for JSON tool call detection (called by engine after initialization) pub fn set_tokenizer(&mut self, tokenizer: Arc) { - // Get the token ID for "}" character - if let Ok(tokens) = tokenizer.encode("}", false) { - if let Some(&token_id) = tokens.get_ids().last() { - self.json_end_token_id = Some(token_id); - crate::log_info!("JSON end token ID (}}) set to: {}", token_id); + self.json_end_token_ids.clear(); + + for ch in ["}", "]"] { + if let Ok(tokens) = tokenizer.encode(ch, false) { + if let Some(&token_id) = tokens.get_ids().last() { + if !self.json_end_token_ids.contains(&token_id) { + self.json_end_token_ids.push(token_id); + } + } } } + + if !self.json_end_token_ids.is_empty() { + crate::log_info!("JSON end token IDs set to: {:?}", self.json_end_token_ids); + } self.tokenizer = Some(tokenizer); } @@ -182,6 +190,75 @@ impl Scheduler { id } + /// Check if the sequence has grammar validation failures + /// Uses ModelRunner::validate_sequence_for_grammar() to validate the entire output_ids sequence + /// Returns true if validation failed and rollback is needed + fn should_rollback_for_grammar(&mut self, seq_id: usize, output_ids: &[u32]) -> bool { + let runners = self.block_manager.get_runners(); + let runners_guard = runners.read(); + + if let RunnerType::Thread(model_runner) = &*runners_guard { + if let Some(valid_count) = model_runner.validate_sequence_for_grammar(seq_id, output_ids) { + return valid_count < output_ids.len(); + } + } + + false + } + + /// Rollback a sequence to a specific token position + /// This is called from postprocess() when grammar validation fails + /// The sequence is truncated and cache states are rolled back + pub fn rollback_sequence(&mut self, seq_id: usize, target_tokens: usize) -> Result<()> { + const MAX_ROLLBACK_ATTEMPTS: usize = 3; + + // Find the sequence + let seq = self.running.iter_mut() + .find(|s| s.id == seq_id) + .ok_or_else(|| candle_core::Error::msg(format!("Sequence {} not found", seq_id)))?; + + seq.guidance_rollback_count += 1; + + if seq.guidance_rollback_count > MAX_ROLLBACK_ATTEMPTS { + crate::log_error!( + "[Seq {}] Exceeded {} rollback attempts, marking as errored", + seq_id, MAX_ROLLBACK_ATTEMPTS + ); + seq.status = SequenceStatus::Finished; + return Ok(()); + } + + // Save current state as rollback snapshot (if not already saved) + if seq.rollback_snapshot.is_none() { + seq.save_rollback_snapshot(); + } + + // Get target block count + let target_blocks = target_tokens.div_ceil(self.block_manager.get_block_size()); + + // Truncate Sequence state + seq.token_ids.truncate(target_tokens); + seq.block_table.truncate(target_blocks); + seq.num_cached_tokens = target_blocks * self.block_manager.get_block_size(); + + // Rollback BlockManager (KV cache + prefix cache) + self.block_manager.rollback_to_seq_tokens(seq, target_tokens)?; + + // Rollback ModelRunner (llguidance FSM + Mamba state) + let runners = self.block_manager.get_runners().clone(); + { + let runners_guard = runners.read(); + if let RunnerType::Thread(model_runner) = &*runners_guard { + model_runner.rollback_sequence_for_guidance(seq_id, target_tokens)?; + } + } + + // Update sequence status for reprocessing + seq.status = SequenceStatus::Running; + + Ok(()) + } + pub fn is_finished(&self) -> bool { self.waiting.is_empty() && self.running.is_empty() } @@ -535,6 +612,38 @@ impl Scheduler { } } + // Check for grammar validation failures using llguidance + // Validate the entire output_ids sequence + let seq = &self.running[idx]; + let seq_id = seq.id; + let output_ids = seq.output_ids.clone(); + + if self.should_rollback_for_grammar(seq_id, &output_ids) { + let target_tokens = output_ids.len(); + let target_blocks = target_tokens.div_ceil(self.block_manager.get_block_size()); + let target_tokens_aligned = target_blocks * self.block_manager.get_block_size(); + + crate::log_info!( + "[Seq {}] Grammar validation failed, rolling back to {} tokens ({} blocks)", + seq_id, + target_tokens_aligned, + target_blocks + ); + + // Trigger rollback + if let Err(e) = self.rollback_sequence(seq_id, target_tokens_aligned) { + crate::log_error!( + "[Seq {}] Rollback failed: {}. Finishing sequence.", + seq_id, + e + ); + let seq = &mut self.running[idx]; + seq.status = SequenceStatus::Finished; + self.block_manager.deallocate(seq); + } + continue; + } + let matched_stop_sequence_idx = self.stop_sequence_match_index(token, &self.running[idx]); let hit_stop_sequence = matched_stop_sequence_idx.is_some(); @@ -1134,7 +1243,7 @@ impl Scheduler { /// Check if the given token is a tool call end token /// This supports both: /// 1. Explicit tool call end tokens (e.g., in XML format) - /// 2. JSON end token "}" combined with Regex validation for {..."name":..., "arguments":...} pattern + /// 2. JSON end token "}" combined with prefix_could_be_tool validation pub fn is_tool_call_end(&self, token: u32, idx: usize) -> bool { // 1. Check for explicit tool call end tokens (XML style) if self.tool_call_end_token_ids.contains(&token) { @@ -1152,19 +1261,24 @@ impl Scheduler { return true; } - // 2. Check for JSON style tool call using Regex - // This handles models like Qwen3 that output raw JSON without XML tags - if self.json_end_token_id == Some(token) { + // 2. Check for JSON style tool call by attempting to parse complete JSON + if self.json_end_token_ids.contains(&token) { if let Some(tokenizer) = &self.tokenizer { // Temporarily add the token to get complete output for decoding let mut temp_output = self.running[idx].output_ids.to_vec(); temp_output.push(token); if let Ok(decoded) = tokenizer.decode(&temp_output, true) { - // Check for JSON tool call pattern using Regex - // The pattern matches if the decoded string ends with a valid JSON tool call structure - if self.tool_call_regex.is_match(&decoded) { - return true; + let trimmed = decoded.trim(); + if let Ok(val) = serde_json::from_str::(trimmed) { + if val.is_object() || val.is_array() { + return true; + } + } else { + let (_could_be, is_complete) = prefix_could_be_tool(trimmed); + if is_complete { + return true; + } } } } @@ -1173,37 +1287,15 @@ impl Scheduler { false } - fn stop_sequence_match_index(&self, token: u32, seq: &Sequence) -> Option { - let Some(stop_sequences) = &seq.sampling_params.stop_token_ids else { - return None; - }; - if stop_sequences.is_empty() { - return None; - } - - for (idx, stop) in stop_sequences.iter().enumerate() { - if stop.is_empty() { - continue; - } - if stop.len() == 1 { - if stop[0] == token { - return Some(idx); - } - continue; - } - - let prior_len = seq.output_ids.len(); - if stop.len() - 1 > prior_len { - continue; - } - let start_idx = prior_len + 1 - stop.len(); - if seq.output_ids[start_idx..] == stop[..stop.len() - 1] - && stop[stop.len() - 1] == token - { - return Some(idx); - } - } + /// Get the EOS token IDs from the scheduler + pub fn eos_token_ids(&self) -> &[u32] { + &self.eos_token_id + } + fn stop_sequence_match_index(&self, _token: u32, seq: &Sequence) -> Option { + // Stop sequence matching now uses SpecialTokens + // The actual matching is done via SpecialTokens.search() + seq.sampling_params.stop_sequences.as_ref()?; None } } diff --git a/src/core/sequence.rs b/src/core/sequence.rs index d096bbe4..00de6cb2 100644 --- a/src/core/sequence.rs +++ b/src/core/sequence.rs @@ -28,6 +28,13 @@ impl fmt::Display for SequenceStatus { } } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RollbackSnapshot { + pub block_table: Vec, + pub num_cached_tokens: usize, + pub mamba_prefix_hash: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Sequence { pub id: usize, @@ -47,6 +54,10 @@ pub struct Sequence { pub is_tool_call_end: bool, pub hit_stop_sequence: bool, pub stop_sequence: Option, + /// Snapshot for rollback on speculative decoding mismatch + pub rollback_snapshot: Option, + /// Rollback counter for guidance constraints to prevent infinite loops + pub guidance_rollback_count: usize, } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -179,6 +190,8 @@ impl Sequence { is_tool_call_end: false, hit_stop_sequence: false, stop_sequence: None, + rollback_snapshot: None, + guidance_rollback_count: 0, } } @@ -235,4 +248,22 @@ impl Sequence { pub fn clear_block_table(&mut self) { self.block_table.clear(); } + + /// Save current state as rollback snapshot + pub fn save_rollback_snapshot(&mut self) { + self.rollback_snapshot = Some(RollbackSnapshot { + block_table: self.block_table.clone(), + num_cached_tokens: self.num_cached_tokens, + mamba_prefix_hash: self.mamba_prefix_hash, + }); + } + + /// Restore from rollback snapshot + pub fn restore_from_snapshot(&mut self) { + if let Some(snapshot) = self.rollback_snapshot.take() { + self.block_table = snapshot.block_table; + self.num_cached_tokens = snapshot.num_cached_tokens; + self.mamba_prefix_hash = snapshot.mamba_prefix_hash; + } + } } diff --git a/src/main.rs b/src/main.rs index b7ac1b1f..61ce7374 100644 --- a/src/main.rs +++ b/src/main.rs @@ -182,36 +182,38 @@ async fn main() -> Result<()> { }; let econfig = EngineConfig::new( - args.model_id, - args.weight_path, - args.weight_file, - args.hf_token, - args.hf_token_path, - args.enforce_parser.clone(), - Some(std::cmp::max(max_num_seqs, prompts.len())), - None, - max_model_len, - Some(args.max_tokens), - args.isq.clone(), - Some(1), - args.device_ids.clone(), - generation_cfg, - args.seed, - Some(prefix_cache), - args.prefix_cache_max_tokens, - Some(args.fp8_kvcache), - Some(args.server || args.ui_server || !interactive), - args.cpu_mem_fold, - args.kv_fraction, - args.mamba_fraction, - pd_config, - args.mcp_command.clone(), - args.mcp_config.clone(), - args.mcp_args.clone(), - tool_prompt_template, - None, // pd_server_prefix_cache_ratio - None, // pd_client_prefix_cache_ratio - ); + args.model_id, + args.weight_path, + args.weight_file, + args.hf_token, + args.hf_token_path, + args.enforce_parser.clone(), + Some(std::cmp::max(max_num_seqs, prompts.len())), + None, + max_model_len, + Some(args.max_tokens), + args.isq.clone(), + Some(1), + args.device_ids.clone(), + generation_cfg, + args.seed, + Some(prefix_cache), + args.prefix_cache_max_tokens, + Some(args.fp8_kvcache), + Some(args.server || args.ui_server || !interactive), + args.cpu_mem_fold, + args.kv_fraction, + args.mamba_fraction, + pd_config, + args.mcp_command.clone(), + args.mcp_config.clone(), + args.mcp_args.clone(), + tool_prompt_template, + None, // pd_server_prefix_cache_ratio + None, // pd_client_prefix_cache_ratio + args.allow_constraint_api, + args.enable_tool_grammar, + ); let engine = LLMEngine::new(&econfig, dtype)?; if args.server || args.ui_server || args.pd_server { diff --git a/src/models/gemma3/config.rs b/src/models/gemma3/config.rs index 548c2e91..f2c7c4e1 100644 --- a/src/models/gemma3/config.rs +++ b/src/models/gemma3/config.rs @@ -1,5 +1,4 @@ use crate::serde_default; -use crate::utils::config::EosTokenId; use crate::utils::config::QuantConfig; use crate::utils::config::RopeScalingValue; use candle_nn::Activation; @@ -100,7 +99,7 @@ pub struct Gemma3Config { pub vision_config: VisionConfig, pub image_token_index: usize, pub mm_tokens_per_image: usize, - pub eos_token_id: Option, + pub eos_token_id: Option>, #[serde(default = "has_vision")] pub has_vision: bool, } diff --git a/src/models/gemma3/mod.rs b/src/models/gemma3/mod.rs index 68e435c2..600aafd6 100644 --- a/src/models/gemma3/mod.rs +++ b/src/models/gemma3/mod.rs @@ -538,11 +538,7 @@ impl Gemma3ForConditionalGeneration { } else { vb.pp("language_model.model.embed_tokens") }, - if is_qvar_builder || g_cfg.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let embed_scale = (config.text_config.hidden_size as f64).sqrt(); @@ -659,6 +655,17 @@ impl Gemma3ForConditionalGeneration { }) } + fn embed_forward(&self, input_ids: &Tensor) -> Result { + let xs = self.embed_tokens.forward(input_ids)?; + let xs = if (self.is_qvar_builder || self.g_cfg.quant.is_some()) && xs.dtype() != DType::F32 + { + xs.to_dtype(DType::F32)? + } else { + xs + }; + xs * self.embed_scale + } + fn vision_tower( &self, image_features: &Tensor, @@ -687,7 +694,7 @@ impl Gemma3ForConditionalGeneration { ) -> Result { let text_cfg = &self.config.text_config; // 1. Prepare Text Embeddings (Scaled) - let mut xs = (self.embed_tokens.forward(input_ids)? * self.embed_scale)?; + let mut xs = self.embed_forward(input_ids)?; // vision projection and embedding if let Some(images) = images { diff --git a/src/models/glm4.rs b/src/models/glm4.rs index bdce6101..33ea6608 100644 --- a/src/models/glm4.rs +++ b/src/models/glm4.rs @@ -206,11 +206,7 @@ impl GLM4ForCausalLM { } else { vb.pp("model.embed_tokens") }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -293,7 +289,12 @@ impl GLM4ForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -319,7 +320,7 @@ impl GLM4ForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/models/glm4_moe.rs b/src/models/glm4_moe.rs index 18c7c75b..b2d1855c 100644 --- a/src/models/glm4_moe.rs +++ b/src/models/glm4_moe.rs @@ -305,11 +305,7 @@ impl GLM4MoEForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -393,7 +389,12 @@ impl GLM4MoEForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -420,7 +421,7 @@ impl GLM4MoEForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/models/llama.rs b/src/models/llama.rs index 2bf3e408..450512d4 100644 --- a/src/models/llama.rs +++ b/src/models/llama.rs @@ -173,11 +173,7 @@ impl LLaMaForCausalLM { } else { vb.pp("model.embed_tokens").clone() }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( @@ -262,7 +258,12 @@ impl LLaMaForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -287,7 +288,7 @@ impl LLaMaForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/models/phi4.rs b/src/models/phi4.rs index 05fe2759..7444eb08 100644 --- a/src/models/phi4.rs +++ b/src/models/phi4.rs @@ -513,11 +513,7 @@ impl Phi4ForCausalLM { } else { vb.pp("model.embed_tokens") }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(Phi4RotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -595,6 +591,15 @@ impl Phi4ForCausalLM { }) } + pub fn embed_forward(&self, xs: &Tensor) -> Result { + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } + } + fn forward_inner( &self, input_ids: &Tensor, @@ -620,7 +625,7 @@ impl Phi4ForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/models/qwen3.rs b/src/models/qwen3.rs index 8d7f1bb0..6d653a42 100644 --- a/src/models/qwen3.rs +++ b/src/models/qwen3.rs @@ -214,11 +214,7 @@ impl Qwen3ForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -301,7 +297,12 @@ impl Qwen3ForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -328,7 +329,7 @@ impl Qwen3ForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { for ((k_cache, v_cache), (i, layer)) in diff --git a/src/models/qwen3_5.rs b/src/models/qwen3_5.rs index d7a339d5..b22dffd1 100644 --- a/src/models/qwen3_5.rs +++ b/src/models/qwen3_5.rs @@ -328,11 +328,7 @@ impl Qwen3_5ForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( @@ -475,7 +471,12 @@ impl Qwen3_5ForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -503,7 +504,7 @@ impl Qwen3_5ForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; let mut kv_cache_idx = 0usize; diff --git a/src/models/qwen3_5_moe.rs b/src/models/qwen3_5_moe.rs index a014fb65..529a9caf 100644 --- a/src/models/qwen3_5_moe.rs +++ b/src/models/qwen3_5_moe.rs @@ -441,11 +441,7 @@ impl Qwen3_5MoEForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( @@ -586,7 +582,12 @@ impl Qwen3_5MoEForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -613,7 +614,7 @@ impl Qwen3_5MoEForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; let mut kv_cache_idx = 0usize; diff --git a/src/models/qwen3_moe.rs b/src/models/qwen3_moe.rs index 6ad153bc..52e30187 100644 --- a/src/models/qwen3_moe.rs +++ b/src/models/qwen3_moe.rs @@ -348,11 +348,7 @@ impl Qwen3MoEForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -436,7 +432,12 @@ impl Qwen3MoEForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -463,7 +464,7 @@ impl Qwen3MoEForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/py/mod.rs b/src/py/mod.rs index 54d59061..2303c1bc 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -7,6 +7,7 @@ use crate::transfer::{PdConfig, PdMethod, PdRole}; use crate::utils::chat_template::Message; use crate::utils::config::{EngineConfig, GenerationConfig, SamplingParams}; use crate::utils::get_dtype; +use llguidance::api::TopLevelGrammar; use parking_lot::RwLock; use pyo3::exceptions::PyStopIteration; use pyo3::exceptions::PyValueError; @@ -268,7 +269,8 @@ impl EngineConfig { fp8_kvcache=None, server_mode=None, cpu_mem_fold=None, kv_fraction=None, mamba_fraction=None, pd_config=None, mcp_command=None, mcp_config=None, mcp_args=None, tool_prompt_template=None, - pd_server_prefix_cache_ratio=None, pd_client_prefix_cache_ratio=None))] + pd_server_prefix_cache_ratio=None, pd_client_prefix_cache_ratio=None, + allow_constraint_api=false, enable_tool_grammar=false))] pub fn new( model_id: Option, weight_path: Option, @@ -299,6 +301,8 @@ impl EngineConfig { tool_prompt_template: Option, pd_server_prefix_cache_ratio: Option, pd_client_prefix_cache_ratio: Option, + allow_constraint_api: bool, + enable_tool_grammar: bool, ) -> Self { let mut device_ids = device_ids.unwrap_or_default(); if device_ids.is_empty() { @@ -342,6 +346,8 @@ impl EngineConfig { tool_prompt_template, pd_server_prefix_cache_ratio, pd_client_prefix_cache_ratio, + allow_constraint_api, + enable_tool_grammar, } } } @@ -351,7 +357,8 @@ impl SamplingParams { #[new] #[pyo3(signature = (temperature=None, max_tokens=None, ignore_eos=Some(false), top_k=None, top_p=None, session_id=None, - frequency_penalty=None, presence_penalty=None, thinking=None))] + frequency_penalty=None, presence_penalty=None, thinking=None, + grammar_json=None))] pub fn new( temperature: Option, max_tokens: Option, @@ -362,7 +369,13 @@ impl SamplingParams { frequency_penalty: Option, presence_penalty: Option, thinking: Option, + grammar_json: Option, ) -> Self { + // Convert grammar_json to TopLevelGrammar if present + let grammar = grammar_json.as_ref().and_then(|s| { + serde_json::from_str::(s).ok() + }); + Self { temperature, max_tokens, @@ -374,8 +387,9 @@ impl SamplingParams { presence_penalty, mcp_mode: None, stop_sequences: None, - stop_token_ids: None, thinking, + grammar_json, + grammar, } } @@ -392,8 +406,25 @@ impl SamplingParams { presence_penalty: None, mcp_mode: None, stop_sequences: None, - stop_token_ids: None, thinking: None, + grammar_json: None, + grammar: None, + } + } + + #[getter] + fn grammar_json(&self) -> Option { + self.grammar.as_ref().and_then(|g| serde_json::to_string(g).ok()) + } + + #[setter] + fn set_grammar_json(&mut self, value: Option) { + self.grammar_json = value.clone(); + // Also update grammar from JSON if provided + if let Some(ref s) = value { + self.grammar = serde_json::from_str::(s).ok(); + } else { + self.grammar = None; } } } diff --git a/src/runner/mod.rs b/src/runner/mod.rs index a22e2445..6d0497dd 100644 --- a/src/runner/mod.rs +++ b/src/runner/mod.rs @@ -10,6 +10,7 @@ use interprocess::local_socket::Stream as LocalStream; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::io::{Read, Write}; +use rmp_serde; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct RunnerInitRequest { pub rank: usize, @@ -253,7 +254,7 @@ pub fn send_local( let serialized = if use_json { serde_json::to_vec(message).expect("JSON serialization failed") } else { - bincode::serialize(message).expect("Bincode serialization failed") + rmp_serde::to_vec(message).expect("RMP serialization failed") }; for stream in streams.iter_mut() { @@ -285,7 +286,7 @@ pub fn receive_local(stream: &mut LocalStream, use_json: bool) -> std::io::Resul let message: MessageType = if use_json { serde_json::from_slice(&serialized).expect("JSON deserialization failed") } else { - bincode::deserialize(&serialized).expect("Bincode deserialization failed") + rmp_serde::from_slice(&serialized).expect("RMP deserialization failed") }; // Send acknowledgment diff --git a/src/runner/runner.rs b/src/runner/runner.rs index ee3d8bed..f6c3b1c3 100644 --- a/src/runner/runner.rs +++ b/src/runner/runner.rs @@ -7,13 +7,14 @@ use std::io::Write; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use tokenizers::Tokenizer; use vllm_rs::core::runner::{ModelRunner, Seqs}; use vllm_rs::models::layers::distributed::Comm; use vllm_rs::models::layers::VarBuilderX; use vllm_rs::runner::{receive_local, send_local, MessageType}; use vllm_rs::transfer::PdRole; use vllm_rs::transfer::Transfer; -use vllm_rs::utils::guidance::load_toktrie_from_path; +use vllm_rs::utils::guidance::build_llg_factory; use vllm_rs::utils::heartbeat::heartbeat_worker; use vllm_rs::utils::new_device; use vllm_rs::utils::progress::{ProgressLike, ProgressReporter, RemoteProgressReporter}; @@ -126,31 +127,43 @@ fn main() -> anyhow::Result<()> { (None, false) }; - let vb = VarBuilderX::new( - &init_req.model_pathes, - init_req.is_gguf, - init_req.dtype.into(), - &device, - )?; let stream_kv = Some(stream.try_clone()?); let mut econfig = init_req.econfig.clone(); - let toktrie = load_toktrie_from_path(&init_req.model_pathes.get_tokenizer_filename()) - .map(Arc::new); + let tokenizer = Tokenizer::from_file(init_req.model_pathes.get_tokenizer_filename()) + .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?; + let llg_factory = match build_llg_factory(tokenizer, init_req.config.vocab_size) { + Ok(f) => Some(f), + Err(e) => { + vllm_rs::log_warn!("Failed to build llguidance factory: {}", e); + None + } + }; #[allow(unused_mut)] - let mut runner = ModelRunner::new( - init_req.model_type, - &vb, - comm, - &mut econfig, - &init_req.config, - init_req.dtype.into(), - init_req.is_rope_i, - device, - progress_reporter, - transfer, - toktrie, - stream_kv, - )?; + let mut runner = { + let _guard = candle_core::InferenceMode::enter(); + let vb = VarBuilderX::new( + &init_req.model_pathes, + init_req.is_gguf, + init_req.dtype.into(), + &device, + )?; + let runner = ModelRunner::new( + init_req.model_type, + &vb, + comm, + &mut econfig, + &init_req.config, + init_req.dtype.into(), + init_req.is_rope_i, + device, + progress_reporter, + transfer, + llg_factory, + stream_kv, + )?; + drop(vb); + runner + }; vllm_rs::log_info!( "Runner at rank {} created (PD config: {:?})!", diff --git a/src/server/mod.rs b/src/server/mod.rs index 46732b30..b55dc18d 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,3 +1,4 @@ +// src/server/mod.rs use clap::Parser; use serde::{Deserialize, Serialize}; pub mod claude_server; @@ -10,6 +11,7 @@ use crate::server::streaming::Streamer; use crate::transfer::PdRole; use crate::utils::chat_template::Message; use crate::utils::config::EngineConfig; +use crate::utils::guidance::TopLevelGrammarExt; use crate::utils::image::{ compute_tokens_per_image, get_tensor_raw_data, load_image_from_base64, load_image_from_url, ImageData, ImageProcessConfig, ImageProcessTrait, IMAGE_PLACEHOLDER, @@ -26,6 +28,7 @@ use parking_lot::RwLock; use rustchatui::start_ui_server; use serde_json::json; use std::collections::HashMap; +use crate::tools::schema::{schema_to_tools, ToolGrammarBuilder}; use std::path::Path; use std::sync::Arc; use tower_http::cors::{Any, CorsLayer}; @@ -42,6 +45,8 @@ pub struct ChatCompletionRequest { pub presence_penalty: Option, #[serde(alias = "enable_thinking")] pub thinking: Option, + #[serde(default, alias = "stop_sequences")] + pub stop: Option>, pub stream: Option, pub session_id: Option, /// Tools available for the model to call @@ -50,6 +55,22 @@ pub struct ChatCompletionRequest { /// How the model should choose which tool to call #[serde(default)] pub tool_choice: Option, + /// OpenAI-style response format for structured outputs + #[serde(default)] + pub response_format: Option, + /// Extra body for OpenAI-compatible clients (e.g. structured_outputs) + #[serde(default)] + pub extra_body: Option, + /// Direct structured_outputs for convenience (parsed from extra_body if not present) + #[serde(default, alias = "structured_outputs")] + pub structured_outputs: Option, + /// Legacy constraint field for llguidance (llg-new.diff pattern) + /// Use constraint_type to specify grammar format: "regex", "lark", "json_schema" + #[serde(alias = "grammar", default)] + pub constraint: Option, + /// Type of constraint for legacy constraint field + #[serde(default)] + pub constraint_type: Option, } pub fn resolve_engine_model_id(econfig: &EngineConfig) -> Option { @@ -100,6 +121,163 @@ impl Default for EncodingFormat { } } +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub struct StructuredOutputs { + #[serde(default)] + pub choice: Option>, + #[serde(default)] + pub regex: Option, + #[serde(default)] + pub json: Option, + #[serde(default)] + pub grammar: Option, + #[serde(default)] + pub structural_tag: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub struct ResponseFormatJsonSchema { + #[serde(default)] + pub name: Option, + pub schema: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub struct ResponseFormat { + #[serde(rename = "type")] + pub format_type: String, + #[serde(default)] + pub json_schema: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub struct ExtraBody { + #[serde(default)] + pub structured_outputs: Option, + #[serde(flatten)] + pub extra: HashMap, +} + +// TopLevelGrammar conversion functions +// Client grammars are composed via merge_top_level_grammars alongside TEXT and tool grammars. + +pub fn grammar_fragment_from_structured_outputs(structured: &StructuredOutputs) -> Result> { + crate::log_debug!("[llg] grammar_fragment_from_structured_outputs() called"); + + let mut selected: Option = None; + let mut constraint_count = 0; + + if let Some(choice) = &structured.choice { + if !choice.is_empty() { + constraint_count += 1; + if constraint_count > 1 { + crate::log_error!("[llg] Multiple constraints specified - structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag"); + return Err(candle_core::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + crate::log_debug!("[llg] Building choice grammar from: {:?}", choice); + let choice_gram = crate::tools::schema::build_choice_lark_grammar(choice) + .map_err(|e| candle_core::Error::msg(e))?; + selected = Some(choice_gram); + } + } + + if let Some(regex) = &structured.regex { + constraint_count += 1; + if constraint_count > 1 { + crate::log_error!("[llg] Multiple constraints specified - structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag"); + return Err(candle_core::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + crate::log_debug!("[llg] Building regex grammar: {}", regex); + let regex_gram = TopLevelGrammarExt::from_regex_ascii(regex); + selected = Some(regex_gram); + } + + if let Some(schema) = &structured.json { + constraint_count += 1; + if constraint_count > 1 { + crate::log_error!("[llg] Multiple constraints specified - structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag"); + return Err(candle_core::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + crate::log_debug!("[llg] Building JSON schema grammar"); + let schema = crate::tools::schema::sanitize_schema_for_llguidance(schema); + let json_gram = TopLevelGrammarExt::from_json_schema_utf8(schema) + .map_err(|e| candle_core::Error::msg(e.to_string()))?; + selected = Some(json_gram); + } + + if let Some(grammar) = &structured.grammar { + constraint_count += 1; + if constraint_count > 1 { + crate::log_error!("[llg] Multiple constraints specified - structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag"); + return Err(candle_core::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + crate::log_debug!("[llg] Using Lark grammar from structured_outputs.grammar"); + let lark_gram = TopLevelGrammarExt::from_lark_utf8(grammar); + selected = Some(lark_gram); + } + + if let Some(tag) = &structured.structural_tag { + constraint_count += 1; + if constraint_count > 1 { + crate::log_error!("[llg] Multiple constraints specified - structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag"); + return Err(candle_core::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + crate::log_debug!("[llg] Building tool call grammar from structural_tag"); + let (start, end, schema) = crate::tools::schema::parse_structural_tag(tag) + .map_err(|e| candle_core::Error::msg(e))?; + let schema = crate::tools::schema::sanitize_schema_for_llguidance(&schema); + // Convert schema Value to Vec for build_json_tool_lark_grammar + let tools = schema_to_tools(&schema); + // structural_tag uses text-based matching, pass None for token IDs + let tool_gram = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag(&start) + .end_tag(&end) + .start_is_special(false) + .end_is_special(false) + .build_json(); + selected = Some(tool_gram); + } + + if selected.is_none() { + crate::log_error!("[llg] No constraint specified in structured_outputs - must set exactly one of choice, regex, json, grammar, or structural_tag"); + return Err(candle_core::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + + crate::log_info!("[llg] grammar_fragment_from_structured_outputs() completed with grammar: {:?}", selected.is_some()); + Ok(selected) +} + +pub fn grammar_fragment_from_response_format(response_format: &ResponseFormat) -> Result> { + crate::log_debug!("[llg] grammar_fragment_from_response_format() called with type: {}", response_format.format_type); + + match response_format.format_type.as_str() { + "json_schema" => { + let Some(schema) = response_format.json_schema.as_ref() else { + crate::log_error!("[llg] response_format.json_schema is required for type=json_schema"); + return Err(candle_core::Error::msg("response_format.json_schema is required")); + }; + crate::log_debug!("[llg] Building JSON schema grammar from response_format"); + let schema = crate::tools::schema::sanitize_schema_for_llguidance(&schema.schema); + let json_gram = TopLevelGrammarExt::from_json_schema_utf8(schema) + .map_err(|e| candle_core::Error::msg(e.to_string()))?; + crate::log_info!("[llg] grammar_fragment_from_response_format() completed with grammar"); + Ok(Some(json_gram)) + } + other => { + crate::log_error!("[llg] Unsupported response_format type '{}'; only 'json_schema' is supported", other); + Err(candle_core::Error::msg(format!( + "Unsupported response_format type '{}'; only 'json_schema' is supported", + other + ))) + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "snake_case")] pub enum EmbeddingStrategy { @@ -602,6 +780,14 @@ pub struct Args { /// MCP server arguments (comma-separated) #[arg(long, value_delimiter = ',', default_value = None)] pub mcp_args: Option>, + + /// Allow client-submitted constraints via HTTP API + #[arg(long, default_value = "false")] + pub allow_constraint_api: bool, + + /// Whether to automatically build LLG grammar from tools + #[arg(long, default_value = "false")] + pub enable_tool_grammar: bool, } /// Result of executing tool calls via MCP @@ -1249,10 +1435,7 @@ mod tests { #[test] fn test_chat_completion_tool_choice_required_parsing() { - let json = r#"{ - "messages": [{"role":"user","content":"hi"}], - "tool_choice": "required" - }"#; + let json = r#"{"messages": [{"role":"user","content":"hi"}], "tool_choice": "required"}"#; let request: ChatCompletionRequest = serde_json::from_str(json).unwrap(); assert!(matches!( request.tool_choice, @@ -1261,4 +1444,138 @@ mod tests { )) )); } + + #[test] + fn test_grammar_fragment_from_structured_outputs_choice() { + let so = StructuredOutputs { + choice: Some(vec!["option1".to_string(), "option2".to_string()]), + regex: None, + json: None, + grammar: None, + structural_tag: None, + }; + let result = grammar_fragment_from_structured_outputs(&so); + assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + } + + #[test] + fn test_grammar_fragment_from_structured_outputs_json() { + let so = StructuredOutputs { + choice: None, + regex: None, + json: Some(serde_json::json!({"type": "object", "properties": {}})), + grammar: None, + structural_tag: None, + }; + let result = grammar_fragment_from_structured_outputs(&so); + assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + } + + #[test] + fn test_grammar_fragment_from_structured_outputs_regex() { + let so = StructuredOutputs { + choice: None, + regex: Some("^[a-z]+$".to_string()), + json: None, + grammar: None, + structural_tag: None, + }; + let result = grammar_fragment_from_structured_outputs(&so); + assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + } + + #[test] + fn test_grammar_fragment_from_structured_outputs_grammar() { + let so = StructuredOutputs { + choice: None, + regex: None, + json: None, + // Grammar without start: - that's managed by ComposedGrammar + grammar: Some("'hello' 'world'".to_string()), + structural_tag: None, + }; + let result = grammar_fragment_from_structured_outputs(&so); + assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + } + + #[test] + fn test_grammar_fragment_from_structured_outputs_empty() { + let so = StructuredOutputs { + choice: None, + regex: None, + json: None, + grammar: None, + structural_tag: None, + }; + let result = grammar_fragment_from_structured_outputs(&so); + assert!(result.is_err()); + } + + #[test] + fn test_grammar_fragment_from_structured_outputs_too_many() { + let so = StructuredOutputs { + choice: Some(vec!["a".to_string()]), + regex: Some("b".to_string()), + json: None, + grammar: None, + structural_tag: None, + }; + let result = grammar_fragment_from_structured_outputs(&so); + assert!(result.is_err()); + } + + #[test] + fn test_grammar_fragment_from_response_format_json_schema() { + let rf = ResponseFormat { + format_type: "json_schema".to_string(), + json_schema: Some(ResponseFormatJsonSchema { + name: None, + schema: serde_json::json!({"type": "object", "properties": {}}), + }), + }; + let result = grammar_fragment_from_response_format(&rf); + assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + } + + #[test] + fn test_grammar_fragment_from_response_format_missing_json_schema() { + let rf = ResponseFormat { + format_type: "json_schema".to_string(), + json_schema: None, + }; + let result = grammar_fragment_from_response_format(&rf); + assert!(result.is_err()); + } + + #[test] + fn test_grammar_fragment_from_response_format_unsupported_type() { + let rf = ResponseFormat { + format_type: "unsupported".to_string(), + json_schema: None, + }; + let result = grammar_fragment_from_response_format(&rf); + assert!(result.is_err()); + } + + #[test] + fn test_grammar_fragment_from_response_format_json_schema_composed() { + // Test that json_schema grammars pass through ComposedGrammar + let rf = ResponseFormat { + format_type: "json_schema".to_string(), + json_schema: Some(ResponseFormatJsonSchema { + name: None, + schema: serde_json::json!({"type": "object", "properties": {"test": {"type": "string"}}}), + }), + }; + let result = grammar_fragment_from_response_format(&rf); + assert!(result.is_ok()); + // The grammar was created via ComposedGrammar - just verify it's Some + let grammar = result.unwrap(); + assert!(grammar.is_some()); + } } diff --git a/src/server/parser.rs b/src/server/parser.rs index a95ab859..f771e1b9 100644 --- a/src/server/parser.rs +++ b/src/server/parser.rs @@ -50,6 +50,8 @@ pub struct ToolConfig { pub end_token_ids: HashSet, pub start_token_str: String, pub end_token_str: String, + pub start_is_special: bool, + pub end_is_special: bool, } impl ToolConfig { @@ -68,6 +70,8 @@ impl ToolConfig { end_token_ids: end_ids, start_token_str: "<|python_tag|>".to_string(), end_token_str: "<|eom_id|>".to_string(), + start_is_special: false, + end_is_special: false, } } ModelType::Qwen3 @@ -83,6 +87,8 @@ impl ToolConfig { end_token_ids: end_ids, start_token_str: "".to_string(), end_token_str: "".to_string(), + start_is_special: false, + end_is_special: false, } } ModelType::Mistral | ModelType::Mistral3VL => { @@ -93,6 +99,8 @@ impl ToolConfig { end_token_ids: end_ids, start_token_str: "[TOOL_CALLS]".to_string(), end_token_str: "]".to_string(), + start_is_special: false, + end_is_special: false, } } ModelType::Gemma | ModelType::Gemma3 => { @@ -102,6 +110,8 @@ impl ToolConfig { end_token_ids: end_ids, start_token_str: "".to_string(), end_token_str: "".to_string(), + start_is_special: false, + end_is_special: false, } } // Phi, GLM, Yi, StableLM, DeepSeek - use Qwen format (text-only) @@ -116,6 +126,8 @@ impl ToolConfig { end_token_ids: HashSet::new(), start_token_str: "".to_string(), end_token_str: "".to_string(), + start_is_special: false, + end_is_special: false, }, } } @@ -1046,7 +1058,7 @@ impl StreamToolParser { serde_json::from_str::>(trimmed).is_ok() } - fn parser_name_for_model(model_type: &ModelType, model_id: &str) -> &'static str { + pub fn parser_name_for_model(model_type: &ModelType, model_id: &str) -> &'static str { let model_lower = model_id.to_ascii_lowercase(); match model_type { ModelType::LLaMa => "llama", diff --git a/src/server/server.rs b/src/server/server.rs index eb88978a..69000061 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -1,6 +1,9 @@ // src/server/server.rs use super::logger::ChatCompletionLogger; +use crate::utils::guidance::{compose_grammars, get_lark_from_top_level_grammar, TopLevelGrammarExt}; +use llguidance::api::TopLevelGrammar; use super::{ + grammar_fragment_from_structured_outputs, grammar_fragment_from_response_format, build_messages_and_images, streaming::{ChatResponse, Streamer, StreamingStatus}, ChatResponder, DetokenizeRequest, DetokenizeResponse, EmbeddingRequest, EmbeddingResponse, @@ -16,8 +19,10 @@ use crate::server::parser::{BufferedFinalizeResult, StreamResult, StreamToolPars use crate::tools::helpers::{ build_invalid_tool_call_feedback, build_tool_schema_map, filter_tool_calls, log_tool_calls, resolve_tools, retain_tool_calls_forced_name, strict_tool_call_validation_enabled, + sanitize_tools_for_llguidance, }; use crate::tools::{ToolChoice, ToolChoiceMode}; +use crate::tools::schema::ToolGrammarBuilder; use crate::utils::config::SamplingParams; use axum::{ extract::{Json, Query, State}, @@ -32,6 +37,7 @@ use tokio::sync::watch; use tokio::task; use uuid::Uuid; + /// Helper struct to manage streaming response chunks /// Provides clean API for sending tokens, errors, and status notifications struct StreamingContext { @@ -275,14 +281,124 @@ pub async fn chat_completion( params.session_id = request.session_id.clone(); params.thinking = request.thinking.clone(); let (img_cfg, model_type, tool_config, engine_config) = { - let e = data.engine.read(); - ( - e.img_cfg.clone(), - e.model_type.clone(), - e.tool_config.clone(), - e.econfig.clone(), - ) - }; + let e = data.engine.read(); + ( + e.img_cfg.clone(), + e.model_type.clone(), + e.tool_config.clone(), + e.econfig.clone(), + ) + }; + let model_type = model_type.clone(); // Clone for later use + + // Collect all TopLevelGrammars from various sources + let mut constraint_grammars: Vec = Vec::new(); + + // Handle client-submitted constraints via structured_outputs or response_format + // First check top-level structured_outputs for convenience + if let Some(ref structured) = request.structured_outputs { + if engine_config.allow_constraint_api { + match grammar_fragment_from_structured_outputs(structured) { + Ok(Some(grammar)) => { + constraint_grammars.push(grammar); + crate::log_debug!("[llg] Collected constraint grammar from top-level structured_outputs"); + } + Ok(None) => { + // No constraint specified + } + Err(err) => { + crate::log_error!("[llg] Failed to parse structured_outputs: {:?}", err); + return ChatResponder::ValidationError(format!("{:?}", err)); + } + } + } else { + crate::log_warn!("[llg] Client-submitted constraints are disabled. Set --allow-constraint-api to enable."); + } + } + // Fallback to extra_body.structured_outputs for backwards compatibility + else if let Some(ref extra_body) = request.extra_body { + if let Some(ref structured) = extra_body.structured_outputs { + if engine_config.allow_constraint_api { + match grammar_fragment_from_structured_outputs(structured) { + Ok(Some(grammar)) => { + constraint_grammars.push(grammar); + crate::log_debug!("[llg] Collected constraint grammar from extra_body.structured_outputs"); + } + Ok(None) => { + // No constraint specified + } + Err(err) => { + crate::log_error!("[llg] Failed to parse structured_outputs: {:?}", err); + return ChatResponder::ValidationError(format!("{:?}", err)); + } + } + } else { + crate::log_warn!("[llg] Client-submitted constraints are disabled. Set --allow-constraint-api to enable."); + } + } + } + + if let Some(ref response_format) = request.response_format { + if engine_config.allow_constraint_api { + match grammar_fragment_from_response_format(response_format) { + Ok(Some(grammar)) => { + constraint_grammars.push(grammar); + crate::log_debug!("[llg] Collected constraint grammar from response_format"); + } + Ok(None) => { + // No constraint specified + } + Err(err) => { + crate::log_error!("[llg] Failed to parse response_format: {:?}", err); + return ChatResponder::ValidationError(format!("{:?}", err)); + } + } + } else { + crate::log_warn!("[llg] Client-submitted constraints are disabled. Set --allow-constraint-api to enable."); + } + } + + // Legacy constraint field (PROTECTED by allow_constraint_api flag) + if engine_config.allow_constraint_api { + if let Some(ref grammar_str) = request.constraint { + let constraint_type = request.constraint_type.as_deref().unwrap_or("regex"); + match constraint_type { + "regex" => { + let llg_grammar = TopLevelGrammarExt::from_regex_ascii(grammar_str); + constraint_grammars.push(llg_grammar); + crate::log_debug!("[llg] Generated regex constraint"); + } + "lark" => { + let llg_grammar = TopLevelGrammarExt::from_lark_utf8(grammar_str); + constraint_grammars.push(llg_grammar); + crate::log_debug!("[llg] Generated lark constraint"); + } + "json_schema" | "json" => { + match serde_json::from_str::(grammar_str) { + Ok(val) => { + match TopLevelGrammarExt::from_json_schema_utf8(val) { + Ok(llg_grammar) => { + constraint_grammars.push(llg_grammar); + crate::log_debug!("[llg] Generated json_schema constraint"); + } + Err(e) => { + crate::log_warn!("[llg] Failed to parse json_schema constraint: {:?}", e); + } + } + } + Err(e) => { + crate::log_warn!("[llg] Failed to parse json_schema constraint: {:?}", e); + } + } + } + _ => { + crate::log_warn!("[llg] Unknown constraint_type: {}", constraint_type); + } + } + } + } else { + crate::log_warn!("[llg] Client-submitted constraints are disabled. Set --allow-constraint-api to enable."); + } let mcp_tools = data .mcp_manager @@ -334,10 +450,17 @@ pub async fn chat_completion( } } - let tool_schemas = Arc::new(build_tool_schema_map(&resolved_tools)); + // Sanitize tools before building schema map to ensure ASCII-only tool names + let sanitized_tools = sanitize_tools_for_llguidance(&resolved_tools); + let tool_schemas = Arc::new(build_tool_schema_map(&sanitized_tools)); let has_tools = !resolved_tools.is_empty(); params.mcp_mode = if has_tools { Some(true) } else { None }; + // Compose all grammars using compose_grammars from guidance.rs + // Clone forced_tool_name for later use in retain_tool_calls_forced_name + let forced_tool_name_clone = forced_tool_name.clone(); + + if has_tools { crate::log_warn!("Tools enabled for request"); } @@ -347,8 +470,44 @@ pub async fn chat_completion( return ChatResponder::ValidationError(err); } let parser_model_id = - super::resolve_engine_model_id(&engine_config).unwrap_or_else(|| model_id.clone()); - let enforce_parser = engine_config.enforce_parser.clone(); + super::resolve_engine_model_id(&engine_config).unwrap_or_else(|| model_id.clone()); + let enforce_parser = engine_config.enforce_parser.clone(); + + // Build tool grammar based on parser type (XML for qwen_coder, JSON for others) + // Honor parser override flag (--enforce-parser) when available + let tool_parser_name = if let Some(ref enforced) = enforce_parser { + enforced.clone() + } else { + StreamToolParser::parser_name_for_model(&model_type, &parser_model_id).to_string() + }; + let use_xml_grammar = tool_parser_name == "qwen_coder"; + let tool_gram = if has_tools && engine_config.enable_tool_grammar { + let tool_gram = if use_xml_grammar { + crate::tools::schema::build_xml_tool_lark_grammar( + &sanitized_tools, + &tool_config.start_token_str, + &tool_config.end_token_str, + tool_config.start_is_special, + tool_config.end_is_special, + Some(&tool_config.start_token_ids), + Some(&tool_config.end_token_ids), + ) + } else { + ToolGrammarBuilder::new() + .tools(&sanitized_tools) + .start_tag(&tool_config.start_token_str) + .end_tag(&tool_config.end_token_str) + .start_is_special(tool_config.start_is_special) + .end_is_special(tool_config.end_is_special) + .start_token_ids(Some(tool_config.start_token_ids.clone())) + .end_token_ids(Some(tool_config.end_token_ids.clone())) + .build_json() + }; + crate::log_debug!("[llg] Built tool grammar (use_xml_grammar={})", use_xml_grammar); + Some(tool_gram) + } else { + None + }; let (messages, image_data) = match build_messages_and_images(&chat_messages, img_cfg.as_ref()) { Ok(output) => output, @@ -363,6 +522,28 @@ pub async fn chat_completion( .unwrap() .as_millis() as u64; + if constraint_grammars.is_empty() && !engine_config.enable_tool_grammar { + crate::log_debug!("[llg] No constraint or tool grammar - not setting guidance"); + } else { + // Get SpecialTokens from engine for building TEXT pattern with EOS bounding + let engine = data.engine.read(); + let special_tokens = &engine.special_tokens; + let llg_grammar = compose_grammars( + constraint_grammars, + tool_gram, + has_tools, + tool_choice_required, + forced_tool_name.clone(), + Some(max_tokens.clone()), + special_tokens, + ); + drop(engine); // Explicitly drop the lock guard + let lark_string = get_lark_from_top_level_grammar(&llg_grammar); + crate::log_debug!("[llg] TopLevelGrammar for SamplingParams: {:?}", &llg_grammar); + crate::log_debug!("[llg] Lark grammar string:\n{}", lark_string); + params.grammar = Some(llg_grammar); + } + if use_stream { let session_id = params.session_id.clone(); if let Some(sid) = session_id { @@ -401,7 +582,6 @@ pub async fn chat_completion( enforce_parser.clone(), ); tool_parser.set_initial_reasoning_end_marker(prefilled_reasoning_end); - let forced_tool_name = forced_tool_name.clone(); let stream_tool_schemas = tool_schemas.clone(); if let Some(ref l) = logger { l.log_start_response(); @@ -725,7 +905,7 @@ pub async fn chat_completion( let dropped = retain_tool_calls_forced_name( &mut pending_tool_calls, - forced_tool_name.as_deref(), + forced_tool_name_clone.as_deref(), ); if dropped > 0 { crate::log_warn!( @@ -752,7 +932,7 @@ pub async fn chat_completion( let invalid_feedback = build_invalid_tool_call_feedback( &invalid_calls, stream_tool_schemas.as_ref(), - forced_tool_name.as_deref(), + forced_tool_name_clone.as_deref(), ); let (valid_calls, invalid_feedback) = if !invalid_calls.is_empty() @@ -1019,7 +1199,7 @@ pub async fn chat_completion( .parse_complete_with_fallback(&output.decode_output) .await; let dropped = - retain_tool_calls_forced_name(&mut parsed_calls, forced_tool_name.as_deref()); + retain_tool_calls_forced_name(&mut parsed_calls, forced_tool_name_clone.as_deref()); if dropped > 0 { crate::log_warn!( "Dropped {} tool call(s) that did not match forced tool_choice", @@ -1036,7 +1216,7 @@ pub async fn chat_completion( let invalid_feedback = build_invalid_tool_call_feedback( &invalid_calls, tool_schemas.as_ref(), - forced_tool_name.as_deref(), + forced_tool_name_clone.as_deref(), ); let valid_calls = validated_calls; diff --git a/src/tools/helpers.rs b/src/tools/helpers.rs index bcd1706b..6c810098 100644 --- a/src/tools/helpers.rs +++ b/src/tools/helpers.rs @@ -34,6 +34,16 @@ pub fn strict_tool_call_validation_enabled() -> bool { }) } +pub fn sanitize_tools_for_llguidance(tools: &[Tool]) -> Vec { + tools.iter().map(sanitize_tool_schema).collect() +} + +fn sanitize_tool_schema(tool: &Tool) -> Tool { + let mut tool = tool.clone(); + tool.function.parameters = crate::tools::schema::sanitize_schema_for_llguidance(&tool.function.parameters); + tool +} + /// Build a map of tool names to their parameter schemas pub fn build_tool_schema_map(tools: &[Tool]) -> HashMap { tools diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 4270fa08..8e4cbfe9 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -23,7 +23,7 @@ pub struct ToolBuilder { } impl ToolBuilder { - fn new(name: String, description: String) -> Self { + pub fn new(name: String, description: String) -> Self { Self { name, description, @@ -259,22 +259,45 @@ impl ToolFormat { let config = ToolConfig::for_model_type(model_type); let start_tag = &config.start_token_str; let end_tag = &config.end_token_str; - let rule = format!( - "MOST IMPORTANT INSTRUCTION, **MUST** FOLLOW: For each function call, you MUST wrap function name and arguments in {start_tag}{end_tag} tags.\n\n\ - Do NOT USE ANY code blocks. Required format:\n\ - {start_tag}\n\ - {{\"name\": \"\", \"arguments\": }}\n\ - {end_tag}\n\n\ - Rules:\n\ - - Wrap function name and arguments with {start_tag} and {end_tag} tags\n\ - - Always use the exact {start_tag}{end_tag} format shown above\n\ - - Do NOT USE ANY code blocks\n\ - - Tool-use must be placed **at the end** of your response (**AFTER REASONING**), **top-level**, and not nested within other tags.\n\ - - Always adhere to this format for the tool use to ensure proper parsing and execution.\n\ - - The \"name\" and \"arguments\" are necessary fields\n\ - - DO NOT call ANY functions that DOES NOT defined between and \n\ - - MUST FOLLOW the above instruction when using tool call!", - ); - rule + match model_type { + crate::utils::config::ModelType::Qwen3 + | crate::utils::config::ModelType::Qwen3MoE + | crate::utils::config::ModelType::Qwen3VL => { + format!( + "MOST IMPORTANT INSTRUCTION, **MUST** FOLLOW: For each function call, you MUST use the QwenCoder tool format.\n\n\ + Required format:\n\ + {start_tag}\n\ + >\n\ + >\n\ + ...\n\ + \n\ + {end_tag}\n\n\ + Rules:\n\ + - Wrap tool calls with {start_tag} and {end_tag}\n\ + - Use and tags\n\ + - Each value MUST be valid JSON (string/object/array/number/bool)\n\ + - Do NOT USE ANY code blocks\n\ + - Tool-use must be placed at the end of your response (after reasoning)\n\ + - Only call tools defined between and \n\ + - MUST FOLLOW the above instruction when using tool call!", + ) + } + _ => format!( + "MOST IMPORTANT INSTRUCTION, **MUST** FOLLOW: For each function call, you MUST wrap function name and arguments in {start_tag}{end_tag} tags.\n\n\ + Do NOT USE ANY code blocks. Required format:\n\ + {start_tag}\n\ + {{\"name\": \"\", \"arguments\": }}\n\ + {end_tag}\n\n\ + Rules:\n\ + - Wrap function name and arguments with {start_tag} and {end_tag} tags\n\ + - Always use the exact {start_tag}{end_tag} format shown above\n\ + - Do NOT USE ANY code blocks\n\ + - Tool-use must be placed **at the end** of your response (**AFTER REASONING**), **top-level**, and not nested within other tags.\n\ + - Always adhere to this format for the tool use to ensure proper parsing and execution.\n\ + - The \"name\" and \"arguments\" are necessary fields\n\ + - DO NOT call ANY functions that DOES NOT defined between and \n\ + - MUST FOLLOW the above instruction when using tool call!", + ), + } } } diff --git a/src/tools/parser.rs b/src/tools/parser.rs index 99b81c00..a8c09b76 100644 --- a/src/tools/parser.rs +++ b/src/tools/parser.rs @@ -5,7 +5,10 @@ use super::{new_tool_call, ToolCall}; use regex::Regex; -use serde_json::Value; +use serde::{de::{Deserializer, MapAccess, Visitor}}; +use serde_json::{Map, Value}; +use std::fmt; +use std::sync::OnceLock; /// Parser for extracting tool calls from model output text #[allow(dead_code)] @@ -47,12 +50,18 @@ impl ToolParser { /// Parse tool calls from model output /// Only parses tool calls from the final answer (after reasoning end markers) pub fn parse(&self, text: &str) -> Vec { - let mut calls = Vec::new(); let mut call_id = 0; // Extract only the final answer portion (after reasoning ends) let final_answer = Self::extract_final_answer(text); + // Mistral-style parsing: strip wrappers and parse JSON or JSON array. + let mut calls = parse_tool_calls_from_text(&final_answer, &mut call_id); + + if !calls.is_empty() { + return calls; + } + // Try Qwen format first if let Some(qwen_calls) = self.parse_qwen_format(&final_answer, &mut call_id) { calls.extend(qwen_calls); @@ -157,10 +166,8 @@ impl ToolParser { } if let Ok(parsed) = serde_json::from_str::(trimmed) { - if let Some(call) = self.value_to_tool_call(&parsed, call_id) { - calls.push(call); - } - } + calls.extend(self.value_to_tool_call(&parsed, call_id)); + } } } } @@ -179,10 +186,11 @@ impl ToolParser { // Simple approach: try to parse the entire text as JSON first if let Ok(parsed) = serde_json::from_str::(text.trim()) { - if let Some(call) = self.value_to_tool_call(&parsed, call_id) { - return Some(vec![call]); - } - } + let parsed_calls = self.value_to_tool_call(&parsed, call_id); + if parsed_calls.is_some() { + return Some(vec![parsed_calls.unwrap()]); + } + } // Look for JSON blocks in the text let mut depth = 0; @@ -202,9 +210,7 @@ impl ToolParser { if let Some(s) = start { let json_str = &text[s..=i]; if let Ok(parsed) = serde_json::from_str::(json_str) { - if let Some(call) = self.value_to_tool_call(&parsed, call_id) { - calls.push(call); - } + calls.extend(self.value_to_tool_call(&parsed, call_id).into_iter()); } } start = None; @@ -229,9 +235,7 @@ impl ToolParser { for cap in re.captures_iter(text) { if let Some(content) = cap.get(1) { if let Ok(parsed) = serde_json::from_str::(content.as_str().trim()) { - if let Some(call) = self.value_to_tool_call(&parsed, call_id) { - calls.push(call); - } + calls.extend(self.value_to_tool_call(&parsed, call_id).into_iter()); } } } @@ -303,6 +307,300 @@ impl ToolParser { } } +// --- Mistral-style tool parsing helpers --- + +// Accept either `{...}` **or** a `"stringified { ... }"` +fn flexible_args<'de, D>(d: D) -> std::result::Result +where + D: Deserializer<'de>, +{ + struct ArgVisitor; + + impl<'de> Visitor<'de> for ArgVisitor { + type Value = Value; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("an object or a JSON-encoded string containing an object") + } + + fn visit_map(self, mut m: M) -> std::result::Result + where + M: MapAccess<'de>, + { + let mut map = Map::new(); + while let Some((k, v)) = m.next_entry()? { + map.insert(k, v); + } + Ok(Value::Object(map)) + } + + fn visit_str(self, s: &str) -> std::result::Result + where + E: serde::de::Error, + { + serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}"))) + } + } + + d.deserialize_any(ArgVisitor) +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +struct CalledFunctionParameters { + #[serde(alias = "function")] + name: String, + #[serde(alias = "arguments", deserialize_with = "flexible_args")] + parameters: Value, +} + +fn contains_tool_call_prefix(prefix: &str) -> bool { + prefix.contains("") + || prefix.contains("<|tool▁call▁begin|>") + || prefix.contains("<|python_tag|>") + || prefix.contains("[TOOL_CALLS]") +} + +fn process_model_specific_message(message: &str) -> String { + static DEEPSEEK_REGEX: OnceLock = OnceLock::new(); + static QWEN_REGEX: OnceLock = OnceLock::new(); + + let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| { + Regex::new( + r"(?s)<|tool▁call▁begin|>function<|tool▁sep|>(?P[^\n]+)\n```json\n(?P.+?)\n```<|tool▁call▁end|>", + ) + .unwrap() + }); + let qwen_regex = QWEN_REGEX + .get_or_init(|| Regex::new(r"(?s)(?P.*?)").unwrap()); + + if let Some(message) = message.strip_prefix("<|python_tag|>") { + message + .strip_suffix("<|eom_id|>") + .unwrap_or(message) + .to_string() + } else if qwen_regex.is_match(message) { + if let Some(caps) = qwen_regex.captures(message) { + let inner = caps.name("inner").unwrap().as_str(); + return inner.trim().to_string(); + } + message.to_string() + } else if let Some(message) = message + .strip_prefix("[TOOL_CALLS][") + .and_then(|s| s.strip_suffix("]")) + { + message.to_string() + } else if deepseek_regex.find(message).is_some() { + let mut calls = Vec::new(); + for caps in deepseek_regex.captures_iter(message) { + let name = caps + .name("name") + .map(|m| m.as_str().trim().to_string()) + .unwrap_or_default(); + let json_str = caps.name("json").map(|m| m.as_str().trim()).unwrap_or("{}"); + let arguments: Value = + serde_json::from_str(json_str).unwrap_or_else(|_| Value::Object(Map::new())); + let args_str = serde_json::to_string(&arguments).unwrap_or_else(|_| "{}".to_string()); + calls.push(new_tool_call( + format!("call_{}", calls.len()), + name, + args_str, + )); + } + serde_json::to_string(&calls).unwrap_or_else(|_| message.to_string()) + } else { + message.to_string() + } +} + +fn fix_broken_json(raw: &str) -> String { + if raw.contains(r#""arguments":"{"#) { + let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1); + tmp.replacen(r#"}"}"#, r#"}}"#, 1) + } else { + raw.to_string() + } +} + +fn json_value_to_tool_call(value: &Value, call_id: &mut usize) -> Option { + let name = value.get("name")?.as_str()?.to_string(); + let arguments = value.get("arguments")?; + let args_str = if arguments.is_string() { + arguments.as_str().unwrap_or("{}").to_string() + } else { + serde_json::to_string(arguments).ok()? + }; + + let call = new_tool_call( + format!("call_{}", call_id), + name, + args_str, + ); + *call_id += 1; + Some(call) +} + +/// Parse tool calls from a raw message string (handles model-specific wrappers). +pub fn parse_tool_calls_from_text(text: &str, call_id: &mut usize) -> Vec { + // First, handle explicit wrappers (may appear multiple times) + if text.contains("") { + let mut calls = Vec::new(); + if let Ok(re) = Regex::new(r"(?s)\s*(.*?)\s*") { + for cap in re.captures_iter(text) { + if let Some(inner) = cap.get(1) { + let inner = inner.as_str().trim(); + if let Ok(parsed) = serde_json::from_str::(inner) { + if let Some(call) = json_value_to_tool_call(&parsed, call_id) { + calls.push(call); + } + continue; + } + + if let Some(call) = parse_function_tag_tool_call(inner, call_id) { + calls.push(call); + } + } + } + } + if !calls.is_empty() { + return calls; + } + } + + let processed = process_model_specific_message(text); + let processed = fix_broken_json(&processed); + + if let Ok(deser) = serde_json::from_str::(&processed) { + let args = serde_json::to_string(&deser.parameters).unwrap_or_else(|_| "{}".to_string()); + let call = new_tool_call( + format!("call_{}", call_id), + deser.name, + args, + ); + *call_id += 1; + return vec![call]; + } + + if let Ok(deser) = serde_json::from_str::>(&processed) { + let mut out = Vec::new(); + for item in deser { + let args = serde_json::to_string(&item.parameters).unwrap_or_else(|_| "{}".to_string()); + out.push(new_tool_call( + format!("call_{}", call_id), + item.name, + args, + )); + *call_id += 1; + } + return out; + } + + Vec::new() +} + +/// Checks if the given prefix could be the start of, or the entire JSON serialization of a tool call. +/// Returns (could_be_tool, is_complete_tool). +pub fn prefix_could_be_tool(prefix: &str) -> (bool, bool) { + if prefix.trim().is_empty() { + return (false, false); + } + + // If we already have a full ..., attempt to parse directly. + if prefix.contains("") { + let mut call_id = 0; + if !parse_tool_calls_from_text(prefix, &mut call_id).is_empty() { + return (false, true); + } + } + + // If we see a start tag, it's at least a potential tool call. + if prefix.contains("") { + return (true, false); + } + + let processed = process_model_specific_message(prefix); + let processed = fix_broken_json(&processed); + + let checks = [ + could_be_json::, + could_be_json::>, + ]; + + for check in checks { + let (could_be, complete) = check(&processed); + if could_be || complete { + return (could_be, complete); + } + } + + ( + contains_tool_call_prefix(prefix) || contains_tool_call_prefix(&processed), + false, + ) +} + +fn could_be_json(text_prefix: &str) -> (bool, bool) +where + T: serde::de::DeserializeOwned, +{ + if text_prefix.trim().is_empty() { + return (false, false); + } + match serde_json::from_str::(text_prefix) { + Ok(_) => (false, true), + Err(e) if e.is_eof() => (true, false), + _ => (false, false), + } +} + +fn parse_function_tag_tool_call(inner: &str, call_id: &mut usize) -> Option { + let func_tag = "')? + name_start; + if name_end <= name_start { + return None; + } + let func_name = inner[name_start..name_end].trim(); + if func_name.is_empty() { + return None; + } + + let mut params = Map::new(); + let mut pos = name_end + 1; + while let Some(param_tag_pos) = inner[pos..].find("") + .map(|v| v + value_start)?; + if value_end <= value_start { + break; + } + let value_raw = inner[value_start..value_end].trim(); + let value = serde_json::from_str::(value_raw) + .unwrap_or_else(|_| Value::String(value_raw.to_string())); + params.insert(key.to_string(), value); + pos = value_end + "".len(); + } + + let args = Value::Object(params); + let args_str = serde_json::to_string(&args).ok()?; + + let call = new_tool_call( + format!("call_{}", call_id), + func_name.to_string(), + args_str, + ); + *call_id += 1; + Some(call) +} + #[cfg(test)] mod tests { use super::*; @@ -372,4 +670,25 @@ mod tests { assert!(parser.has_tool_calls(r#"{"name": "foo", "arguments": {}}"#)); assert!(!parser.has_tool_calls("Just a normal response")); } + + #[test] + fn test_parse_function_tag_format() { + let parser = ToolParser::new(); + let text = r#" + + +{"bar": 1} + + +qux + + +"#; + + let calls = parser.parse(text); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "my_tool"); + assert!(calls[0].clone().function.arguments.unwrap().contains("\"foo\"")); + assert!(calls[0].clone().function.arguments.unwrap().contains("\"baz\"")); + } } diff --git a/src/tools/schema.rs b/src/tools/schema.rs index fa5ebce0..76e1a450 100644 --- a/src/tools/schema.rs +++ b/src/tools/schema.rs @@ -3,8 +3,314 @@ //! //! Provides helpers for working with JSON Schema in tool definitions. -use serde_json::{json, Value}; -use std::collections::HashMap; +use crate::tools::Tool; +use serde_json::{json, Map, Value}; +use std::collections::{HashMap, HashSet}; +use crate::utils::guidance::{TopLevelGrammarExt, GrammarError, GrammarResult}; +use llguidance::api::TopLevelGrammar; + +/// Remove JSON Schema features that llguidance doesn't support. +/// Currently strips all "format" fields recursively. +pub fn sanitize_schema_for_llguidance(schema: &Value) -> Value { + match schema { + Value::Object(map) => { + let mut out = Map::new(); + for (key, value) in map { + if key == "format" { + continue; + } + out.insert(key.clone(), sanitize_schema_for_llguidance(value)); + } + Value::Object(out) + } + Value::Array(items) => { + Value::Array(items.iter().map(sanitize_schema_for_llguidance).collect()) + } + _ => schema.clone(), + } +} + +/// Lark grammar helper functions for llguidance constraint building +/// Sanitize string for Lark grammar - only allow ASCII characters +fn lark_quote(value: &str) -> String { + // Strip non-ASCII characters to prevent grammar parser errors + let sanitized: String = value + .chars() + .filter(|c| c.is_ascii()) + .collect(); + let escaped = sanitized.replace('\\', "\\\\").replace('"', "\\\""); + format!("\"{}\"", escaped) +} + +/// Convert token IDs to Lark special token syntax <[token_id]> +/// This is used when the tokenizer has canonical tokenization for the tag +fn lark_special_token(token_ids: &HashSet) -> String { + if token_ids.is_empty() { + return String::new(); + } + // Join multiple token IDs with | + let ids: Vec = token_ids.iter().map(|id| format!("[{}]", id)).collect(); + format!("<{}>", ids.join(",")) +} + +fn _lark_literal(value: &str, is_special: bool) -> String { + if is_special && value.starts_with('<') && value.ends_with('>') { + // Only allow ASCII special tags + let sanitized: String = value + .chars() + .filter(|c| c.is_ascii()) + .collect(); + sanitized + } else { + lark_quote(value) + } +} + +/// Builder for constructing tool call grammars +pub struct ToolGrammarBuilder { + tools: Vec, + start_tag: String, + end_tag: String, + start_is_special: bool, + end_is_special: bool, + start_token_ids: Option>, + end_token_ids: Option>, +} + +impl ToolGrammarBuilder { + pub fn new() -> Self { + Self { + tools: Vec::new(), + start_tag: String::new(), + end_tag: String::new(), + start_is_special: false, + end_is_special: false, + start_token_ids: None, + end_token_ids: None, + } + } + + pub fn tools(mut self, tools: &[Tool]) -> Self { + self.tools.extend(tools.iter().cloned()); + self + } + + pub fn start_tag(mut self, tag: impl Into) -> Self { + self.start_tag = tag.into(); + self + } + + pub fn end_tag(mut self, tag: impl Into) -> Self { + self.end_tag = tag.into(); + self + } + + pub fn start_is_special(mut self, special: bool) -> Self { + self.start_is_special = special; + self + } + + pub fn end_is_special(mut self, special: bool) -> Self { + self.end_is_special = special; + self + } + + pub fn start_token_ids(mut self, ids: Option>) -> Self { + self.start_token_ids = ids; + self + } + + pub fn end_token_ids(mut self, ids: Option>) -> Self { + self.end_token_ids = ids; + self + } + + /// Build Lark expression for JSON tool schema content + pub fn build_json(self) -> TopLevelGrammar { + let mut rules = Vec::new(); + + let start_tag = self.get_tag_or_token_id(&self.start_tag, &self.start_token_ids, self.start_is_special); + let end_tag = self.get_tag_or_token_id(&self.end_tag, &self.end_token_ids, self.end_is_special); + + rules.push("start: tool_call".to_string()); + rules.push(format!("tool_call: {} tool_obj {}", start_tag, end_tag)); + rules.push("tool_obj: %json {\"type\":\"object\",\"properties\":{\"name\":{\"type\":\"string\"},\"arguments\":{\"type\":\"object\"}},\"required\":[\"name\",\"arguments\"]}".to_string()); + rules.push("json_array: \"[\" obj (\",\" obj)* \"]\"".to_string()); + + for tool in &self.tools { + let tool_name = tool.function.name.replace("-", "_"); + let schema_str = serde_json::to_string(&tool.function.parameters).unwrap_or_default(); + rules.push(format!("obj_{tool_name}: %json {schema_str}")); + } + + if rules.len() <= 4 { + rules.push("obj: %json {\"type\": \"object\"}".to_string()); + } else { + rules.extend(self.tools.iter().enumerate().map(|(_i, t)| { + let name = t.function.name.replace("-", "_"); + format!("obj_{name}: %json {}", serde_json::to_string(&t.function.parameters).unwrap_or_default()) + })); + + let obj_names = self.tools.iter().map(|t| { + format!("obj_{}", t.function.name.replace("-", "_")) + }).collect::>().join(" | "); + rules.push(format!("obj: {}", obj_names)); + } + + // rules.push(format!("ws: {}", lark_ws_regex())); + + let lark = rules.join("\n") + "\n"; + crate::log_debug!("[llg] ToolGrammarBuilder::build_json lark: {}", &lark); + TopLevelGrammar::from_lark_utf8(&lark) + } + + /// Build Lark expression for valid XML parameter content + fn build_xml_value_expression(schema: &serde_json::Value) -> String { + let param_type = schema.get("type").and_then(|t| t.as_str()).unwrap_or("string"); + + match param_type { + "string" => { + // Match any text content without look-around assertions + // Simple pattern: match any character except < or any < followed by non-slash + if let Ok(val) = std::env::var("VLLM_LLG_DEFAULT_XML_STR") { + format!("{}", val) + } else { + r#"/[^<]*/"#.to_string() + // r"/[^<]+(<[^/][^<]*)*/".to_string() + // ^^ nested tag capture produces infinite generation - limitation of XML + } + }, + "integer" => r"/-?[0-9]+/".to_string(), + "number" => r"/-?[0-9]+(\.[0-9]+)?/".to_string(), + "boolean" => r"/^(true|false)$/".to_string(), + "array" => r"/\[[^\]]*\]/".to_string(), + "object" => r"/\{[^\}]*\}/".to_string(), + _ => r"/(?s:.*)/".to_string(), + } + } + + /// Build Lark expression for XML tool schema content + pub fn build_xml(self) -> TopLevelGrammar { + let mut rules: Vec = Vec::new(); + + // Build envelope tag using token IDs when available + let envelope_start_tag = self.get_envelope_tag(&self.start_tag, &self.start_token_ids, self.start_is_special); + let envelope_end_tag = self.get_envelope_tag(&self.end_tag, &self.end_token_ids, self.end_is_special); + + let tool_rule_names: Vec = (0..self.tools.len()).map(|i| format!("tool_{i}")).collect(); + rules.push("start: tool_call".to_string()); + rules.push(format!("tool_call: {} tool_content {}", envelope_start_tag, envelope_end_tag)); + + // Get required params from schema + let get_required_params = |params_schema: &serde_json::Value| -> Vec { + params_schema.get("required") + .and_then(|r| r.as_array()) + .map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()) + .unwrap_or_default() + }; + + for (tool_idx, tool) in self.tools.iter().enumerate() { + let tool_name_ascii: String = tool.function.name.chars().filter(|c| c.is_ascii()).collect(); + let func_start = lark_quote(&format!("", tool_name_ascii)); + let func_end = lark_quote(""); + let params_schema = &tool.function.parameters; + let props = params_schema.get("properties").and_then(|p| p.as_object()); + let required_params = get_required_params(params_schema); + + if let Some(props) = props { + let mut param_rules_vec: Vec = Vec::new(); + + for (param_idx, (param_name, schema)) in props.iter().enumerate() { + let param_name_ascii: String = param_name.chars().filter(|c| c.is_ascii()).collect(); + let param_tag = lark_quote(&format!("", param_name_ascii)); + let param_end = lark_quote(""); + let value_rule = format!("value_{tool_idx}_{param_idx}"); + let param_rule = format!("param_{tool_idx}_{param_idx}"); + + // Determine the Lark expression for valid XML content based on schema type + let value_expr = Self::build_xml_value_expression(schema); + rules.push(format!("{value_rule}: {value_expr}")); + rules.push(format!("{param_rule}: {param_tag} {value_rule} {param_end}")); + + // Add to param_rules_vec with ? for optional, bare for required + if required_params.contains(param_name) { + param_rules_vec.push(param_rule.clone()); + } else { + param_rules_vec.push(format!("({param_rule})?")); + } + } + + let params_expr = param_rules_vec.join(" "); + rules.push(format!("tool_{tool_idx}: {func_start} {params_expr} {func_end}")); + } else { + // No parameters - just function tags + rules.push(format!("tool_{tool_idx}: {func_start} {func_end}")); + } + } + + // Build tool_content with alternation of all tools + let tool_variants = tool_rule_names.join(" | "); + rules.push(format!("tool_content: {tool_variants}")); + // rules.push(format!("_WS: {}", lark_ws_regex())); + + let lark = rules.join("\n") + "\n"; + crate::log_debug!("[llg] ToolGrammarBuilder::build_json lark: {}", &lark); + TopLevelGrammar::from_lark_utf8(&lark) + } + + /// Get envelope tag (start/end) using token IDs when available, falling back to string literals + fn get_envelope_tag(&self, tag: &str, token_ids: &Option>, is_special: bool) -> String { + if let Some(ids) = token_ids { + if !ids.is_empty() { + return lark_special_token(ids); + } + } + + if is_special && tag.starts_with('<') && tag.ends_with('>') { + // Only allow ASCII special tags + let sanitized: String = tag.chars().filter(|c| c.is_ascii()).collect(); + sanitized + } else { + lark_quote(tag) + } + } + + fn get_tag_or_token_id(&self, tag: &str, token_ids: &Option>, is_special: bool) -> String { + if let Some(ids) = token_ids { + if !ids.is_empty() { + return format!("<{}>", ids.iter().map(|id| format!("[{}]", id)).collect::>().join(",")); + } + } + + if is_special && tag.starts_with('<') && tag.ends_with('>') { + tag.to_string() + } else { + lark_quote(tag) + } + } +} + +/// Build a Lark grammar for QwenCoder-style function/parameter tags with JSON values. +/// Used for models like Qwen3-Coder that use XML-style tool call envelopes. +pub fn build_xml_tool_lark_grammar( + tools: &[Tool], + start: &str, + end: &str, + start_is_special: bool, + end_is_special: bool, + start_token_ids: Option<&HashSet>, + end_token_ids: Option<&HashSet>, +) -> TopLevelGrammar { + ToolGrammarBuilder::new() + .tools(tools) + .start_tag(start) + .end_tag(end) + .start_is_special(start_is_special) + .end_is_special(end_is_special) + .start_token_ids(start_token_ids.cloned()) + .end_token_ids(end_token_ids.cloned()) + .build_xml() +} /// Builder for creating JSON Schema objects #[derive(Debug, Clone, Default)] @@ -251,3 +557,1005 @@ pub mod common { .build() } } + +/// Build a Lark grammar for choice constraints (structured outputs choice field) +pub fn build_choice_lark_grammar(choices: &[String]) -> GrammarResult { + if choices.is_empty() { + return Err(GrammarError::InvalidGrammar("structured_outputs.choice must include at least one option".to_string())); + } + + let mut parts = Vec::with_capacity(choices.len()); + for choice in choices { + if choice.is_empty() { + return Err(GrammarError::InvalidGrammar("structured_outputs.choice cannot contain empty strings".to_string())); + } + parts.push(lark_quote(choice)); + } + + let body = parts.join(" | "); + let lark_string = format!("start: {}\n", body); + Ok(TopLevelGrammar::from_lark_utf8(&lark_string)) +} + +/// Normalize a tag string for structural_tag parsing +fn normalize_tag_pair(tag: &str) -> Result<(String, String), String> { + let trimmed = tag.trim(); + if trimmed.is_empty() { + return Err("structured_outputs.structural_tag.tag cannot be empty".to_string()); + } + + if trimmed.starts_with('<') && trimmed.ends_with('>') { + let inner = trimmed + .trim_start_matches('<') + .trim_end_matches('>') + .trim_start_matches('/'); + if inner.is_empty() { + return Err("structured_outputs.structural_tag.tag is invalid".to_string()); + } + let start = if trimmed.starts_with("", inner) + } else { + trimmed.to_string() + }; + let end = format!("", inner); + Ok((start, end)) + } else { + Ok((format!("<{}>", trimmed), format!("", trimmed))) + } +} + +/// Parse structural_tag for structured outputs +pub fn parse_structural_tag(value: &Value) -> Result<(String, String, Value), String> { + let obj = value.as_object().ok_or_else(|| { + "structured_outputs.structural_tag must be an object".to_string() + })?; + + let schema = obj.get("schema").cloned().ok_or_else(|| { + "structured_outputs.structural_tag.schema is required".to_string() + })?; + + let start = obj.get("start_tag").or_else(|| obj.get("start")).or_else(|| obj.get("tag")); + let end = obj.get("end_tag").or_else(|| obj.get("end")); + + let (start_tag, end_tag) = match (start, end) { + (Some(start_val), Some(end_val)) => { + let start = start_val.as_str().ok_or_else(|| { + "structured_outputs.structural_tag.start_tag must be a string".to_string() + })?; + let end = end_val.as_str().ok_or_else(|| { + "structured_outputs.structural_tag.end_tag must be a string".to_string() + })?; + (start.to_string(), end.to_string()) + } + (Some(tag), None) if obj.contains_key("tag") => normalize_tag_pair(tag.as_str().ok_or_else(|| "structured_outputs.structural_tag.tag must be a string".to_string())?)?, + _ => { + return Err("structured_outputs.structural_tag requires tag or start_tag/end_tag".to_string()); + } + }; + + Ok((start_tag, end_tag, schema)) +} + +/// Convert a Value schema to a Vec of Tool objects using ToolBuilder +/// The schema should be an object where keys are tool names and values are tool schemas +pub fn schema_to_tools(schema: &Value) -> Vec { + let mut tools = Vec::new(); + if let Value::Object(obj) = schema { + for (name, tool_schema) in obj { + if let Value::Object(props) = tool_schema { + if let Some(params) = props.get("parameters") { + let builder = crate::tools::ToolBuilder::new(name.clone(), "".to_string()) + .parameters_schema(params.clone()); + tools.push(builder.build()); + } + } + } + } + tools +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::guidance::get_lark_from_top_level_grammar; + + #[test] + fn test_sanitize_schema_for_llguidance_strips_format() { + let schema = json!({ + "type": "object", + "properties": { + "url": {"type": "string", "format": "uri"}, + "nested": {"type": "object", "properties": {"id": {"type": "string", "format": "uuid"}}} + } + }); + let sanitized = sanitize_schema_for_llguidance(&schema); + assert!(sanitized["properties"]["url"].get("format").is_none()); + assert!(sanitized["properties"]["nested"]["properties"]["id"].get("format").is_none()); + } + + #[test] + fn test_build_choice_lark_grammar_empty_string() { + let result = build_choice_lark_grammar(&["".to_string()]); + assert!(result.is_err()); + } + + #[test] + fn test_parse_structural_tag_missing_schema() { + let value = json!({}); + let result = parse_structural_tag(&value); + assert!(result.is_err()); + } + + #[test] + fn test_parse_structural_tag_start_end() { + let value = json!({ + "start_tag": "", + "end_tag": "", + "schema": {"type": "object"} + }); + let result = parse_structural_tag(&value); + assert!(result.is_ok()); + let (start, end, schema) = result.unwrap(); + assert_eq!(start, ""); + assert_eq!(end, ""); + assert_eq!(schema, json!({"type": "object"})); + } + + #[test] + fn test_parse_structural_tag_tag() { + let value = json!({ + "tag": "", + "schema": {"type": "object"} + }); + let result = parse_structural_tag(&value); + assert!(result.is_ok()); + let (start, end, _) = result.unwrap(); + assert_eq!(start, ""); + assert_eq!(end, ""); + } + + #[test] + fn test_parse_structural_tag_invalid() { + let value = json!({ + "schema": {"type": "object"} + }); + let result = parse_structural_tag(&value); + assert!(result.is_err()); + } + + #[test] + fn test_lark_quote_escapes_special_chars() { + let result = lark_quote("test\"value"); + assert!(result.contains("test\\\"value")); + } + + #[test] + fn test_lark_literal_special_tags() { + let result = _lark_literal("", true); + assert_eq!(result, ""); + } + + #[test] + fn test_lark_literal_regular_string() { + let result = _lark_literal("regular", false); + assert!(result.contains("\"regular\"")); + } + + #[test] + fn test_lark_special_token_single_id() { + let mut ids = HashSet::new(); + ids.insert(151657); + let result = lark_special_token(&ids); + assert_eq!(result, "<[151657]>"); + } + + #[test] + fn test_lark_special_token_multiple_ids() { + let mut ids = HashSet::new(); + ids.insert(151657); + ids.insert(151658); + let result = lark_special_token(&ids); + assert!(result.contains("[151657]")); + assert!(result.contains("[151658]")); + } + + #[test] + fn test_lark_special_token_empty() { + let ids = HashSet::new(); + let result = lark_special_token(&ids); + assert_eq!(result, ""); + } + + #[test] + fn test_build_xml_tool_lark_grammar_qwen3_coder_required_only() { + // Test Qwen3-Coder XML tool format with required attributes only + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + ]; + let grammar = build_xml_tool_lark_grammar(&tools, "", "", false, false, None, None); + let lark_str = get_lark_from_top_level_grammar(&grammar); + println!("{}", &lark_str); + + // Qwen3Coder uses XML format with start: tool_call + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains(""), "Should contain function tag"); + assert!(lark_str.contains("tool_0:"), "Should contain tool_0 rule"); + } + + #[test] + fn test_build_xml_tool_lark_grammar_qwen3_coder_optional() { + // Test Qwen3-Coder XML tool format with optional attributes + let tools = vec![ + crate::tools::ToolBuilder::new("get_weather".to_string(), "Get weather".to_string()) + .param("city", "string", "City name", true) + .param("units", "string", "Temperature units (optional)", false) + .build(), + ]; + let grammar = build_xml_tool_lark_grammar(&tools, "", "", false, false, None, None); + let lark_str = get_lark_from_top_level_grammar(&grammar); + + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains(""), "Should contain function tag"); + assert!(lark_str.contains("city"), "Should contain city parameter"); + assert!(lark_str.contains("units"), "Should contain optional units parameter"); + } + + #[test] + fn test_build_xml_tool_lark_grammar_qwen3_coder_deep_parameters() { + // Test Qwen3-Coder XML tool format with nested/complex parameters + let tools = vec![ + crate::tools::ToolBuilder::new("edit_file".to_string(), "Edit a file with complex parameters".to_string()) + .param("file_path", "string", "Path to the file", true) + .param("old_string", "string", "String to replace", true) + .param("new_string", "string", "Replacement string", true) + .param("replace_all", "boolean", "Replace all occurrences", false) + .build(), + ]; + let grammar = build_xml_tool_lark_grammar(&tools, "", "", false, false, None, None); + let lark_str = get_lark_from_top_level_grammar(&grammar); + println!("XML Grammar:\n{}", &lark_str); + + // Verify the grammar contains XML structure + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + // Note: uses U+200C (zero-width non-joiner) which is invisible + assert!(lark_str.contains("function="), "Should contain function tag with attribute"); + + // Verify all parameter tags are present + // Note: uses U+200C (zero-width non-joiner) which is invisible + assert!(lark_str.contains("parameter=file_path"), "Should contain file_path parameter tag"); + assert!(lark_str.contains("parameter=old_string"), "Should contain old_string parameter tag"); + assert!(lark_str.contains("parameter=new_string"), "Should contain new_string parameter tag"); + assert!(lark_str.contains("parameter=replace_all"), "Should contain replace_all parameter tag"); + + // Verify parameter rules reference the correct types + assert!(lark_str.contains("param_0_0:"), "Should have param_0_0 rule for first param"); + assert!(lark_str.contains("param_0_1:"), "Should have param_0_1 rule for second param"); + assert!(lark_str.contains("param_0_2:"), "Should have param_0_2 rule for third param"); + assert!(lark_str.contains("param_0_3:"), "Should have param_0_3 rule for fourth param"); + + // Verify tool rule has all parameters + assert!(lark_str.contains("tool_0:"), "Should have tool_0 rule"); + } + + #[test] + fn test_xml_grammar_required_params_no_wrapper() { + // Test that XML grammar puts required params directly without (...) * wrapper + let tools = vec![crate::tools::ToolBuilder::new("search_tool".to_string(), "Search tool".to_string()) + .param("query", "string", "Search query", true) // REQUIRED - should appear as bare rule reference + .build()]; + + let grammar = build_xml_tool_lark_grammar(&tools, "", "", false, false, None, None); + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Required param rule should appear directly in tool_0 (no parentheses/asterisk around it) + assert!(lark_str.contains("tool_0:"), "Should have tool_0 rule"); + assert!(lark_str.contains("param_0"), "Should have parameter rules"); + + // The required param should NOT be wrapped in (...) * pattern + // Look for the pattern where required params appear as direct references: "param_X Y" not "(param_X | ...)*" + } + + #[test] + fn test_xml_grammar_optional_params_wrapped() { + // Test that XML grammar wraps optional params with (...) * syntax + let tools = vec![crate::tools::ToolBuilder::new("mixed_tool".to_string(), "Mixed params".to_string()) + .param("required_param", "string", "Required", true) // REQUIRED + .param("optional_param", "string", "Optional", false) // OPTIONAL + .build()]; + + let grammar = build_xml_tool_lark_grammar(&tools, "", "", false, false, None, None); + let lark_str = get_lark_from_top_level_grammar(&grammar); + + println!("XML Grammar for mixed tool:\n{}", lark_str); + + // Optional parameters should appear in a (...) * pattern when there are multiple options + assert!(lark_str.contains("tool_0:"), "Should have tool_0 rule"); + } + + #[test] + fn test_xml_tool_call_structure_validates() { + // Full end-to-end: verify XML grammar produces valid llguidance TopLevelGrammar structure + let tools = vec![crate::tools::ToolBuilder::new("formatter".to_string(), "Formatter".to_string()) + .param("text", "string", "Text to format", true) + .build()]; + + let grammar = build_xml_tool_lark_grammar(&tools, "", "", false, false, None, None); + + // Grammar should have at least one sub-grammar (the tool rules) + assert!(grammar.grammars.len() > 0, "Should have generated grammars"); + } + + // === ToolGrammarBuilder JSON Mode Tests === + + #[test] + fn test_tool_grammar_builder_build_json_single_tool() { + // Test ToolGrammarBuilder.build_json() with a single tool + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify basic structure + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("obj_search:"), "Should contain obj_search rule"); + assert!(lark_str.contains("query"), "Should contain query parameter"); + } + + #[test] + fn test_tool_grammar_builder_build_json_multiple_tools() { + // Test ToolGrammarBuilder.build_json() with multiple tools + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + crate::tools::ToolBuilder::new("weather".to_string(), "Get weather".to_string()) + .param("city", "string", "City name", true) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify all tools are present + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("obj_search:"), "Should contain obj_search rule"); + assert!(lark_str.contains("obj_weather:"), "Should contain obj_weather rule"); + // Verify obj alternation includes both tools + assert!(lark_str.contains("obj: obj_search | obj_weather"), "Should have obj alternation"); + } + + #[test] + fn test_tool_grammar_builder_build_json_with_token_ids() { + // Test ToolGrammarBuilder.build_json() with token IDs + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + ]; + let mut start_ids = HashSet::new(); + start_ids.insert(151657); + let mut end_ids = HashSet::new(); + end_ids.insert(151658); + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .start_token_ids(Some(start_ids)) + .end_token_ids(Some(end_ids)) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify token IDs are used + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("<[151657]>"), "Should contain start token ID"); + assert!(lark_str.contains("<[151658]>"), "Should contain end token ID"); + } + + #[test] + fn test_tool_grammar_builder_build_json_with_special_tags() { + // Test ToolGrammarBuilder.build_json() with special tags + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(true) + .end_is_special(true) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify special tags are used as-is + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains(""), "Should contain special start tag"); + assert!(lark_str.contains(""), "Should contain special end tag"); + } + + #[test] + fn test_tool_grammar_builder_build_json_required_optional() { + // Test ToolGrammarBuilder.build_json() with mix of required/optional params + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .param("max_results", "integer", "Max results", false) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify both params are in schema, and required array is correct + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("obj_search:"), "Should contain obj_search rule"); + assert!(lark_str.contains("query"), "Should contain query parameter"); + assert!(lark_str.contains("max_results"), "Should contain max_results parameter"); + assert!(lark_str.contains("\"required\""), "Should have required array"); + } + + // === ToolGrammarBuilder XML Mode Tests === + + #[test] + fn test_tool_grammar_builder_build_xml_single_tool() { + // Test ToolGrammarBuilder.build_xml() with a single tool + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify XML structure + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("tool_call:"), "Should have tool_call rule"); + assert!(lark_str.contains("function=search"), "Should contain function tag"); + assert!(lark_str.contains("parameter=query"), "Should contain parameter tag"); + assert!(lark_str.contains("param_0_0:"), "Should have param_0_0 rule"); + } + + #[test] + fn test_tool_grammar_builder_build_xml_multiple_tools() { + // Test ToolGrammarBuilder.build_xml() with multiple tools + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + crate::tools::ToolBuilder::new("weather".to_string(), "Get weather".to_string()) + .param("city", "string", "City name", true) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify all tools are present + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("tool_0:"), "Should contain tool_0 rule"); + assert!(lark_str.contains("tool_1:"), "Should contain tool_1 rule"); + assert!(lark_str.contains("tool_content:"), "Should have tool_content rule"); + // Verify tool_content has alternation + assert!(lark_str.contains("tool_content: tool_0 | tool_1"), "Should have tool alternation"); + } + + #[test] + fn test_tool_grammar_builder_build_xml_with_token_ids() { + // Test ToolGrammarBuilder.build_xml() with token IDs + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + ]; + let mut start_ids = HashSet::new(); + start_ids.insert(151657); + let mut end_ids = HashSet::new(); + end_ids.insert(151658); + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .start_token_ids(Some(start_ids)) + .end_token_ids(Some(end_ids)) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify token IDs are used for envelope tags + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("<[151657]>"), "Should contain start token ID"); + assert!(lark_str.contains("<[151658]>"), "Should contain end token ID"); + } + + #[test] + fn test_tool_grammar_builder_build_xml_with_special_tags() { + // Test ToolGrammarBuilder.build_xml() with special tags + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(true) + .end_is_special(true) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify special tags are used as-is + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains(""), "Should contain special start tag"); + assert!(lark_str.contains(""), "Should contain special end tag"); + } + + #[test] + fn test_tool_grammar_builder_build_xml_required_optional() { + // Test ToolGrammarBuilder.build_xml() with mix of required/optional params + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .param("max_results", "integer", "Max results", false) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify both params are present + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("param_0_0:"), "Should have param_0_0 rule (query - required)"); + assert!(lark_str.contains("param_0_1:"), "Should have param_0_1 rule (max_results - optional)"); + assert!(lark_str.contains("parameter=query"), "Should contain query parameter tag"); + assert!(lark_str.contains("parameter=max_results"), "Should contain max_results parameter tag"); + } + + #[test] + fn test_tool_grammar_builder_build_xml_no_parameters() { + // Test ToolGrammarBuilder.build_xml() with tool that has no parameters + let tools = vec![ + crate::tools::ToolBuilder::new("hello".to_string(), "Say hello".to_string()) + .param("query", "string", "Search query", true) + .parameters_schema(serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + })) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify tool with no parameters still generates valid grammar + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("function=hello"), "Should contain function tag"); + } + + #[test] + fn test_tool_grammar_builder_build_json_no_parameters() { + // Test ToolGrammarBuilder.build_json() with tool that has no parameters + let tools = vec![ + crate::tools::ToolBuilder::new("hello".to_string(), "Say hello".to_string()) + .parameters_schema(serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + })) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify tool with no parameters still generates valid grammar + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("obj_hello:"), "Should contain obj_hello rule"); + } + + #[test] + fn test_tool_grammar_builder_build_json_empty_tools() { + // Test ToolGrammarBuilder.build_json() with empty tools list + let grammar = ToolGrammarBuilder::new() + .tools(&[]) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify grammar is still valid with no tools + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + // With no tools, obj should be a generic object + assert!(lark_str.contains("obj: %json"), "Should have obj rule with generic schema"); + } + + #[test] + fn test_tool_grammar_builder_build_xml_empty_tools() { + // Test ToolGrammarBuilder.build_xml() with empty tools list + let grammar = ToolGrammarBuilder::new() + .tools(&[]) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Verify grammar is still valid with no tools + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("tool_content:"), "Should have tool_content rule"); + } + + #[test] + fn test_tool_grammar_builder_build_json_structure_validates() { + // Full end-to-end: verify JSON grammar produces valid llguidance TopLevelGrammar structure + let tools = vec![crate::tools::ToolBuilder::new("calculator".to_string(), "Calculator".to_string()) + .param("expression", "string", "Math expression", true) + .build()]; + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_json(); + + // Grammar should have at least one sub-grammar + assert!(grammar.grammars.len() > 0, "Should have generated grammars"); + } + + #[test] + fn test_tool_grammar_builder_build_xml_structure_validates() { + // Full end-to-end: verify XML grammar produces valid llguidance TopLevelGrammar structure + let tools = vec![crate::tools::ToolBuilder::new("formatter".to_string(), "Formatter".to_string()) + .param("text", "string", "Text to format", true) + .build()]; + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .build_xml(); + + // Grammar should have at least one sub-grammar + assert!(grammar.grammars.len() > 0, "Should have generated grammars"); + } + + // === Comprehensive ToolGrammarBuilder Tests === + + #[test] + fn test_tool_grammar_builder_build_xml_complex_full_schema() { + // Test ToolGrammarBuilder.build_xml() with complex nested schema + // and model-specific envelope tags with token IDs + let tools = vec![ + crate::tools::ToolBuilder::new("edit_file".to_string(), "Edit a file".to_string()) + .param("file_path", "string", "Path to the file", true) + .param("old_string", "string", "String to replace", true) + .param("new_string", "string", "Replacement string", true) + .param("max_replacements", "integer", "Maximum replacements", false) + .param("context", "object", "Context object", false) + .param("tags", "array", "Optional tags array", false) + .build(), + ]; + + // Build XML grammar with token IDs for envelope tags + let mut start_ids = HashSet::new(); + start_ids.insert(151657); + let mut end_ids = HashSet::new(); + end_ids.insert(151658); + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .start_token_ids(Some(start_ids)) + .end_token_ids(Some(end_ids)) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + // println!("{}", &lark_str); + + // Validate envelope layer (token IDs) + assert!(lark_str.contains("<[151657]>"), "Should have start token ID envelope"); + assert!(lark_str.contains("<[151658]>"), "Should have end token ID envelope"); + + // Validate tool_call structure + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("tool_call:"), "Should have tool_call rule"); + + // Validate tool_content alternation + assert!(lark_str.contains("tool_content: tool_0"), "Should have tool_content with tool_0"); + + // Validate function tag layer + assert!(lark_str.contains("function=edit_file"), "Should have function tag"); + assert!(lark_str.contains("function="), "Should have function tag pattern"); + + // Validate parameter tags and rules + assert!(lark_str.contains("parameter=file_path"), "Should have file_path parameter tag"); + assert!(lark_str.contains("parameter=old_string"), "Should have old_string parameter tag"); + assert!(lark_str.contains("parameter=new_string"), "Should have new_string parameter tag"); + assert!(lark_str.contains("parameter=max_replacements"), "Should have max_replacements parameter tag"); + assert!(lark_str.contains("parameter=context"), "Should have context parameter tag"); + assert!(lark_str.contains("parameter=tags"), "Should have tags parameter tag"); + + // Validate param rules with correct types + assert!(lark_str.contains("param_0_0:"), "Should have param_0_0 rule (file_path - required)"); + assert!(lark_str.contains("param_0_1:"), "Should have param_0_1 rule (old_string - required)"); + assert!(lark_str.contains("param_0_2:"), "Should have param_0_2 rule (new_string - required)"); + assert!(lark_str.contains("param_0_3:"), "Should have param_0_3 rule (max_replacements - optional)"); + assert!(lark_str.contains("param_0_4:"), "Should have param_0_4 rule (context - optional)"); + assert!(lark_str.contains("param_0_5:"), "Should have param_0_5 rule (tags - optional)"); + + // Validate value rules with regex patterns for each type + assert!(lark_str.contains("value_0_0:"), "Should have value_0_0 rule for file_path"); + assert!(lark_str.contains("value_0_1:"), "Should have value_0_1 rule for old_string"); + assert!(lark_str.contains("value_0_2:"), "Should have value_0_2 rule for new_string"); + assert!(lark_str.contains("value_0_3:"), "Should have value_0_3 rule for max_replacements"); + assert!(lark_str.contains("value_0_4:"), "Should have value_0_4 rule for context"); + assert!(lark_str.contains("value_0_5:"), "Should have value_0_5 rule for tags"); + + // Validate required params are bare (no ? wrapper) + assert!(lark_str.contains("param_0_0 "), "file_path should be bare (required)"); + assert!(lark_str.contains("param_0_1 "), "old_string should be bare (required)"); + assert!(lark_str.contains("param_0_2 "), "new_string should be bare (required)"); + + // Validate optional params have ? wrapper + assert!(lark_str.contains("(param_0_3)?"), "max_replacements should be optional"); + assert!(lark_str.contains("(param_0_4)?"), "context should be optional"); + assert!(lark_str.contains("(param_0_5)?"), "tags should be optional"); + + // Validate tool rule structure + assert!(lark_str.contains("tool_0:"), "Should have tool_0 rule"); + assert!(lark_str.contains("tool_0: \"\""), "Should have tool_0 with function tags"); + } + + #[test] + fn test_tool_grammar_builder_build_json_complex_full_schema() { + // Test ToolGrammarBuilder.build_json() with complex nested schema + // and model-specific envelope tags with token IDs + let tools = vec![ + crate::tools::ToolBuilder::new("edit_file".to_string(), "Edit a file".to_string()) + .param("file_path", "string", "Path to the file", true) + .param("old_string", "string", "String to replace", true) + .param("new_string", "string", "Replacement string", true) + .param("max_replacements", "integer", "Maximum replacements", false) + .param("context", "object", "Context object", false) + .param("tags", "array", "Optional tags array", false) + .build(), + ]; + + // Build JSON grammar with token IDs for envelope tags + let mut start_ids = HashSet::new(); + start_ids.insert(151657); + let mut end_ids = HashSet::new(); + end_ids.insert(151658); + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .start_token_ids(Some(start_ids)) + .end_token_ids(Some(end_ids)) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Validate envelope layer (token IDs) + assert!(lark_str.contains("<[151657]>"), "Should have start token ID envelope"); + assert!(lark_str.contains("<[151658]>"), "Should have end token ID envelope"); + + // Validate tool_call structure + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("tool_call:"), "Should have tool_call rule"); + + // Validate tool_obj structure with name and arguments + assert!(lark_str.contains("tool_obj:"), "Should have tool_obj rule"); + assert!(lark_str.contains("\"name\""), "Should have name in tool_obj"); + assert!(lark_str.contains("\"arguments\""), "Should have arguments in tool_obj"); + + // Validate obj rule references the tool + assert!(lark_str.contains("obj_edit_file:"), "Should have obj_edit_file rule"); + assert!(lark_str.contains("obj: obj_edit_file"), "Should have obj alternation"); + + // Validate JSON schema contains all parameters + assert!(lark_str.contains("file_path"), "Should contain file_path in schema"); + assert!(lark_str.contains("old_string"), "Should contain old_string in schema"); + assert!(lark_str.contains("new_string"), "Should contain new_string in schema"); + assert!(lark_str.contains("max_replacements"), "Should contain max_replacements in schema"); + assert!(lark_str.contains("context"), "Should contain context in schema"); + assert!(lark_str.contains("tags"), "Should contain tags in schema"); + + // Validate required parameters in JSON schema + assert!(lark_str.contains("\"required\""), "Should have required array"); + assert!(lark_str.contains("file_path"), "Should have file_path in required"); + assert!(lark_str.contains("old_string"), "Should have old_string in required"); + assert!(lark_str.contains("new_string"), "Should have new_string in required"); + } + + #[test] + fn test_tool_grammar_builder_build_xml_multiple_tools_full_validation() { + // Test ToolGrammarBuilder.build_xml() with multiple tools and full validation + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .param("max_results", "integer", "Max results", false) + .build(), + crate::tools::ToolBuilder::new("weather".to_string(), "Get weather".to_string()) + .param("city", "string", "City name", true) + .param("units", "string", "Units", false) + .build(), + ]; + + // Build XML grammar with token IDs for envelope tags + let mut start_ids = HashSet::new(); + start_ids.insert(151657); + let mut end_ids = HashSet::new(); + end_ids.insert(151658); + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .start_token_ids(Some(start_ids)) + .end_token_ids(Some(end_ids)) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Validate envelope layer + assert!(lark_str.contains("<[151657]>"), "Should have start token ID envelope"); + assert!(lark_str.contains("<[151658]>"), "Should have end token ID envelope"); + + // Validate tool_content alternation with both tools + assert!(lark_str.contains("tool_content: tool_0 | tool_1"), "Should have tool alternation"); + + // Validate tool_0 (search) structure + assert!(lark_str.contains("tool_0:"), "Should have tool_0 rule"); + assert!(lark_str.contains("function=search"), "Should have search function tag"); + assert!(lark_str.contains("parameter=query"), "Should have query parameter"); + assert!(lark_str.contains("parameter=max_results"), "Should have max_results parameter"); + assert!(lark_str.contains("param_0_0:"), "Should have param_0_0 (query - required)"); + assert!(lark_str.contains("param_0_1:"), "Should have param_0_1 (max_results - optional)"); + + // Validate tool_1 (weather) structure + assert!(lark_str.contains("tool_1:"), "Should have tool_1 rule"); + assert!(lark_str.contains("function=weather"), "Should have weather function tag"); + assert!(lark_str.contains("parameter=city"), "Should have city parameter"); + assert!(lark_str.contains("parameter=units"), "Should have units parameter"); + assert!(lark_str.contains("param_1_0:"), "Should have param_1_0 (city - required)"); + assert!(lark_str.contains("param_1_1:"), "Should have param_1_1 (units - optional)"); + } + + #[test] + fn test_tool_grammar_builder_build_json_multiple_tools_full_validation() { + // Test ToolGrammarBuilder.build_json() with multiple tools and full validation + let tools = vec![ + crate::tools::ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query", "string", "Search query", true) + .param("max_results", "integer", "Max results", false) + .build(), + crate::tools::ToolBuilder::new("weather".to_string(), "Get weather".to_string()) + .param("city", "string", "City name", true) + .param("units", "string", "Units", false) + .build(), + ]; + + // Build JSON grammar with token IDs for envelope tags + let mut start_ids = HashSet::new(); + start_ids.insert(151657); + let mut end_ids = HashSet::new(); + end_ids.insert(151658); + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("") + .end_tag("") + .start_is_special(false) + .end_is_special(false) + .start_token_ids(Some(start_ids)) + .end_token_ids(Some(end_ids)) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + + // Validate envelope layer + assert!(lark_str.contains("<[151657]>"), "Should have start token ID envelope"); + assert!(lark_str.contains("<[151658]>"), "Should have end token ID envelope"); + + // Validate obj alternation with both tools + assert!(lark_str.contains("obj: obj_search | obj_weather"), "Should have obj alternation"); + + // Validate obj_search structure + assert!(lark_str.contains("obj_search:"), "Should have obj_search rule"); + assert!(lark_str.contains("query"), "Should have query in obj_search"); + assert!(lark_str.contains("max_results"), "Should have max_results in obj_search"); + + // Validate obj_weather structure + assert!(lark_str.contains("obj_weather:"), "Should have obj_weather rule"); + assert!(lark_str.contains("city"), "Should have city in obj_weather"); + assert!(lark_str.contains("units"), "Should have units in obj_weather"); + + // Validate required parameters in both schemas + assert!(lark_str.contains("\"required\":[\"query\"]"), "Should have query in required for search"); + assert!(lark_str.contains("\"required\":[\"city\"]"), "Should have city in required for weather"); + } +} diff --git a/src/transfer/comm.rs b/src/transfer/comm.rs index c2c83b39..dbc534ae 100644 --- a/src/transfer/comm.rs +++ b/src/transfer/comm.rs @@ -1,6 +1,6 @@ // src/core/transfer/comm.rs use super::{FinishedPrefillData, PdConfig, PdRole, TransferMessage}; -use bincode; +use rmp_serde; use candle_core::Result; use interprocess::local_socket::traits::Listener; use interprocess::local_socket::traits::Stream; @@ -382,9 +382,9 @@ impl Communicator { } /// Generic, standardized function to send a message. -/// Uses a 4-byte LE length prefix followed by bincode data. +/// Uses a 4-byte LE length prefix followed by rmp data. fn send_message_generic(stream: &mut (impl Read + Write), msg: &TransferMessage) -> Result { - let serialized: Vec = bincode::serialize(msg).map_err(candle_core::Error::wrap)?; + let serialized: Vec = rmp_serde::to_vec(msg).map_err(candle_core::Error::wrap)?; let len = serialized.len() as u32; stream.write_all(&len.to_le_bytes())?; stream.write_all(&serialized)?; @@ -393,7 +393,7 @@ fn send_message_generic(stream: &mut (impl Read + Write), msg: &TransferMessage) } /// Generic, standardized function to receive a message. -/// Reads a 4-byte LE length prefix then bincode data. +/// Reads a 4-byte LE length prefix then rmp data. fn receive_message_generic(stream: &mut (impl Read + Write)) -> Result { let mut len_buf = [0u8; 4]; stream.read_exact(&mut len_buf)?; @@ -407,6 +407,6 @@ fn receive_message_generic(stream: &mut (impl Read + Write)) -> Result Vec { - let mut tokens = tokenizer - .get_added_tokens_decoder() - .into_values() - .filter(|added| added.special) - .map(|added| added.content) + let special_tokens = SpecialTokens::new(tokenizer); + let mut tokens = special_tokens + .all_special() + .into_iter() + .map(|t| t.string()) .collect::>(); for marker in tool_markers { @@ -226,6 +227,19 @@ impl ChatTemplate { escape_special_tokens_in_text(content, &self.escape_tokens, &self.preserve_tokens) } + pub fn supports_tools(&self) -> bool { + let Some(template) = &self.chat_template else { + return false; + }; + let lower = template.to_lowercase(); + lower.contains("tools") + || lower.contains("tool_calls") + || lower.contains("[available_tools]") + || lower.contains("") + } + #[allow(dead_code)] fn clear_message(&mut self) { self.messages.clear() diff --git a/src/utils/command.rs b/src/utils/command.rs index b8cbc74f..1951eadf 100644 --- a/src/utils/command.rs +++ b/src/utils/command.rs @@ -1,5 +1,5 @@ use crate::runner::MessageType; -use bincode; +use rmp_serde; use interprocess::local_socket::traits::{Listener, Stream}; use interprocess::local_socket::{GenericNamespaced, Name, ToNsName}; use interprocess::local_socket::{ListenerOptions, Stream as LocalStream}; @@ -41,7 +41,7 @@ impl CommandManager { streams: &mut Vec, message: &MessageType, ) -> std::io::Result<()> { - let serialized = bincode::serialize(message).expect("Serialization failed"); + let serialized = rmp_serde::to_vec(message).expect("Serialization failed"); for stream in streams.iter_mut() { stream.write_all(&(serialized.len() as u32).to_le_bytes())?; stream.write_all(&serialized)?; @@ -74,7 +74,7 @@ impl CommandManager { let mut serialized = vec![0u8; length]; stream.read_exact(&mut serialized)?; let message: MessageType = - bincode::deserialize(&serialized).expect("Deserialization failed"); + rmp_serde::from_slice(&serialized).expect("Deserialization failed"); // Send acknowledgment stream.write_all(&[1])?; stream.flush()?; diff --git a/src/utils/config.rs b/src/utils/config.rs index d488838a..630dd997 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -1,127 +1,30 @@ // src/utils/config.rs use crate::transfer::PdConfig; +use llguidance::api::TopLevelGrammar; #[cfg(feature = "python")] use pyo3::pyclass; -use serde::de::value::SeqAccessDeserializer; -use serde::de::{Deserializer, Visitor}; -use serde::{Deserialize, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; +use serde::de::Error; use std::collections::HashMap; -use std::fmt; -#[derive(Debug, Clone)] -pub enum EosTokenId { - Single(u32), - Multiple(Vec), -} - -impl<'de> Deserialize<'de> for EosTokenId { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - if deserializer.is_human_readable() { - // For JSON: deserialize as "untagged" using a visitor - struct EosTokenIdVisitor; - - impl<'de> Visitor<'de> for EosTokenIdVisitor { - type Value = EosTokenId; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a u32 or a sequence of u32s") - } - - // Handle a single number - fn visit_u64(self, v: u64) -> Result { - Ok(EosTokenId::Single(v as u32)) - } - - // Handle an array of numbers - fn visit_seq(self, seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let vals = Vec::::deserialize(SeqAccessDeserializer::new(seq))?; - Ok(EosTokenId::Multiple(vals)) - } - } - - deserializer.deserialize_any(EosTokenIdVisitor) - } else { - // For Bincode: deserialize as "tagged" - let bincode_id = BincodeEosTokenId::deserialize(deserializer)?; - let id = match bincode_id { - BincodeEosTokenId::Single(v) => EosTokenId::Single(v), - BincodeEosTokenId::Multiple(v) => EosTokenId::Multiple(v), - }; - Ok(id) - } - } -} - -impl Serialize for EosTokenId { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - if serializer.is_human_readable() { - // For JSON: serialize as "untagged" - match self { - EosTokenId::Single(v) => v.serialize(serializer), - EosTokenId::Multiple(v) => v.serialize(serializer), - } - } else { - // For Bincode: serialize as "tagged" - let bincode_id = match self { - EosTokenId::Single(v) => BincodeEosTokenId::Single(*v), - EosTokenId::Multiple(v) => BincodeEosTokenId::Multiple(v.clone()), - }; - bincode_id.serialize(serializer) - } +#[cfg(not(feature = "python"))] +impl SamplingParams { + /// Convert grammar to constraint for GuidanceState construction + /// Prioritizes constraint field, falls back to grammar field + pub fn to_constraint(&self) -> Option { + self.grammar.clone() } } -impl EosTokenId { - /// Merge `other` into `self`, returning the combined token set. - /// - Single + Single => Multiple([a, b]) - /// - Single + Multiple => Multiple([a, ...]) - /// - Multiple + Single => Multiple([... , b]) - /// - Multiple + Multiple => Multiple([... , ...]) - pub fn merge(self, other: EosTokenId) -> EosTokenId { - let mut out = self.into_vec(); - out.extend(other.into_vec()); - EosTokenId::Multiple(out) - } - - /// Like merge, but de-duplicates while preserving first-seen order. - pub fn merge_dedup(self, other: EosTokenId) -> EosTokenId { - use std::collections::HashSet; - - let mut seen = HashSet::::new(); - let mut out = Vec::::new(); - - for id in self.into_vec().into_iter().chain(other.into_vec()) { - if seen.insert(id) { - out.push(id); - } - } - EosTokenId::Multiple(out) - } - - fn into_vec(self) -> Vec { - match self { - EosTokenId::Single(x) => vec![x], - EosTokenId::Multiple(v) => v, - } +#[cfg(feature = "python")] +impl SamplingParams { + /// Convert grammar to constraint for GuidanceState construction + pub fn to_constraint(&self) -> Option { + self.grammar.clone() } } -// To make the "tagged" logic work for bincode, we need a separate -// definition of the enum with derived traits. We keep it private inside this module. -#[derive(Serialize, Deserialize)] -enum BincodeEosTokenId { - Single(u32), - Multiple(Vec), -} +// EosTokenId enum has been replaced with direct Vec for simplicity #[derive(Serialize, Deserialize, Debug, Clone)] pub struct MoEConfig { @@ -186,7 +89,8 @@ pub struct Config { pub final_logit_softcapping: Option, pub tie_word_embeddings: Option, pub bos_token_id: Option, - pub eos_token_id: Option, + #[serde(deserialize_with = "deserialize_eos_token_id")] + pub eos_token_id: Option>, pub use_sliding_window: Option, pub sliding_window: Option, pub max_window_layers: Option, @@ -217,10 +121,51 @@ impl Config { (None, None) => None, (None, Some(e)) => Some(e.clone()), (Some(e), None) => Some(e), - (Some(e), Some(other)) => Some(e.merge(other.clone())), + (Some(e), Some(other)) => { + let mut merged = e.clone(); + merged.extend(other.clone()); + Some(merged) + } }; } } + +// Custom deserializer for eos_token_id to handle both integer and array formats +fn deserialize_eos_token_id<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: serde::Deserializer<'de>, +{ + use serde::Deserialize; + match Option::::deserialize(deserializer)? { + Some(serde_json::Value::Number(n)) => { + if let Some(id) = n.as_u64() { + Ok(Some(vec![id as u32])) + } else { + Err(serde::de::Error::custom("eos_token_id must be a positive integer")) + } + } + Some(serde_json::Value::Array(arr)) => { + let ids: Result, D::Error> = arr + .into_iter() + .map(|v| { + if let Some(id) = v.as_u64() { + Ok(id as u32) + } else { + Err(D::Error::custom("eos_token_id array must contain only unsigned integers")) + } + }) + .collect(); + Ok(Some(ids?)) + } + Some(serde_json::Value::Null) => Ok(None), + Some(v) => Err(serde::de::Error::custom(format!( + "Expected integer or array for eos_token_id, got {:?}", + v + ))), + None => Ok(None), + } +} + #[cfg(not(feature = "python"))] #[derive(Serialize, Deserialize, Clone, Debug)] pub struct EngineConfig { @@ -263,6 +208,10 @@ pub struct EngineConfig { pub tool_prompt_template: Option, pub pd_server_prefix_cache_ratio: Option, pub pd_client_prefix_cache_ratio: Option, + /// Allow client-submitted constraints via HTTP API + pub allow_constraint_api: bool, + /// Whether to automatically build LLG grammar from tools + pub enable_tool_grammar: bool, } #[cfg(feature = "python")] @@ -340,6 +289,10 @@ pub struct EngineConfig { pub pd_server_prefix_cache_ratio: Option, #[pyo3(get, set)] pub pd_client_prefix_cache_ratio: Option, + #[pyo3(get, set)] + pub allow_constraint_api: bool, + #[pyo3(get, set)] + pub enable_tool_grammar: bool, } #[cfg(not(feature = "python"))] @@ -374,7 +327,9 @@ impl EngineConfig { tool_prompt_template: Option, pd_server_prefix_cache_ratio: Option, pd_client_prefix_cache_ratio: Option, - ) -> Self { + allow_constraint_api: bool, + enable_tool_grammar: bool, + ) -> Self { let mut device_ids = device_ids.unwrap_or_default(); if device_ids.is_empty() { device_ids.push(0); @@ -420,12 +375,14 @@ impl EngineConfig { pd_config, mcp_command, mcp_config, - mcp_args, - tool_prompt_template, - pd_server_prefix_cache_ratio, - pd_client_prefix_cache_ratio, - } - } + mcp_args, + tool_prompt_template, + pd_server_prefix_cache_ratio, + pd_client_prefix_cache_ratio, + allow_constraint_api, + enable_tool_grammar, + } + } } #[derive(Clone, Debug, serde::Deserialize)] @@ -451,14 +408,16 @@ pub struct SamplingParams { pub presence_penalty: Option, #[serde(default)] pub stop_sequences: Option>, - #[serde(skip)] - pub stop_token_ids: Option>>, + // stop_token_ids removed - use SpecialTokens for stop detection #[serde(alias = "enable_thinking")] pub thinking: Option, // enable reasoning /// Tool mode for tool call handling. /// If Some(true), external tools are enabled and stream finishes at . #[serde(default)] pub mcp_mode: Option, + /// Grammar constraint as TopLevelGrammar for RPC serialization + #[serde(default)] + pub grammar: Option, } #[cfg(feature = "python")] @@ -484,15 +443,19 @@ pub struct SamplingParams { #[pyo3(get, set)] #[serde(default)] pub stop_sequences: Option>, - #[serde(skip)] - pub stop_token_ids: Option>>, + // stop_token_ids removed - use SpecialTokens for stop detection /// Tool mode for tool call handling. /// If Some(true), external tools are enabled and stream finishes at . #[pyo3(get, set)] pub mcp_mode: Option, #[pyo3(get, set)] - #[serde(alias = "enable_thinking")] pub thinking: Option, + /// Grammar constraint as TopLevelGrammar for RPC serialization + #[serde(default)] + pub grammar: Option, + /// Grammar constraint as JSON string for Python API + #[pyo3(get, set)] + pub grammar_json: Option, } #[cfg(not(feature = "python"))] @@ -519,7 +482,7 @@ impl SamplingParams { presence_penalty, mcp_mode: None, stop_sequences: None, - stop_token_ids: None, + grammar: None, thinking, } } @@ -536,12 +499,13 @@ impl SamplingParams { presence_penalty: None, mcp_mode: None, stop_sequences: None, - stop_token_ids: None, + grammar: None, thinking: None, } } } +#[cfg(not(feature = "python"))] impl Default for SamplingParams { fn default() -> Self { Self { @@ -555,12 +519,33 @@ impl Default for SamplingParams { presence_penalty: None, mcp_mode: None, stop_sequences: None, - stop_token_ids: None, + grammar: None, thinking: None, } } } +#[cfg(feature = "python")] +impl Default for SamplingParams { + fn default() -> Self { + Self { + temperature: None, + max_tokens: Some(16384), + ignore_eos: false, + top_k: None, + top_p: None, + session_id: None, + frequency_penalty: None, + presence_penalty: None, + mcp_mode: None, + stop_sequences: None, + thinking: None, + grammar: None, + grammar_json: None, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub enum ModelType { Qwen3, @@ -588,8 +573,8 @@ pub struct GenerationConfig { /// Randomness of sampling. /// rec. default = 1 pub temperature: Option, - /// Cumulative prob of the top tokens to consider, must be in (0, 1]. Set 1 to consider all toks. - /// rec. default = 1 + /// Cumulative prob of the top tokens to consider, must be in (0, 1]. Set 1 to consider all toks. + /// rec. default = 1 pub top_p: Option, /// Control the number of top tokens to consider, set -1 to consider all. /// rec. default = -1 @@ -599,7 +584,7 @@ pub struct GenerationConfig { pub presence_penalty: Option, pub bos_token_id: Option, - pub eos_token_id: Option, + pub eos_token_id: Option>, } #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] diff --git a/src/utils/guidance.rs b/src/utils/guidance.rs index e37dbad8..15e3727f 100644 --- a/src/utils/guidance.rs +++ b/src/utils/guidance.rs @@ -1,54 +1,1291 @@ // src/utils/guidance.rs -//! Guided decoding support via llguidance. -//! -//! NOTE: This module is currently stubbed out due to API changes in llguidance >= 0.6. -//! The TopLevelGrammar::from_json_schema method is no longer available. -//! Guided decoding features are temporarily disabled. - -use serde_json::Value; -use std::path::Path; +use anyhow::Result; +use candle_core::Tensor; +use llguidance::{api::TopLevelGrammar, Matcher, ParserFactory as LlgParserFactory}; +use std::collections::HashMap; use std::sync::Arc; +use tokenizers::Tokenizer; +use crate::utils::special_tokens::SpecialTokens; +use toktrie::{SimpleVob, TokTrie}; +use toktrie_hf_tokenizers::{ByteTokenizer, ByteTokenizerEnv}; -// Import toktrie from the crate root (it's re-exported by llguidance) -pub use toktrie::TokTrie; +use crate::tools::Tool; +use crate::utils::logits_processor::{LogitsProcessor, Sampling}; +use serde_json::json; +use crate::tools::schema::ToolGrammarBuilder; + +/// Error type for grammar-related errors +#[derive(Debug, thiserror::Error)] +pub enum GrammarError { + #[error("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")] + TooManyConstraints, + + #[error("response_format.json_schema is required for type=json_schema")] + MissingJsonSchema, + + #[error("unsupported response_format type: {0}")] + UnsupportedFormat(String), + + #[error("invalid grammar: {0}")] + InvalidGrammar(String), + + #[error("tool grammar construction failed: {0}")] + ToolGrammarError(String), +} + +pub type GrammarResult = Result; + +/// Builder for structured output constraint grammars +pub struct ConstraintBuilder { + choice: Option>, + regex: Option, + json: Option, + grammar: Option, + structural_tag: Option, +} + +impl ConstraintBuilder { + pub fn new() -> Self { + Self { + choice: None, + regex: None, + json: None, + grammar: None, + structural_tag: None, + } + } + + pub fn choice(mut self, choice: Vec) -> Self { + self.choice = Some(choice); + self + } + + pub fn regex(mut self, regex: String) -> Self { + self.regex = Some(regex); + self + } + + pub fn json(mut self, json: serde_json::Value) -> Self { + self.json = Some(json); + self + } + + pub fn grammar(mut self, grammar: String) -> Self { + self.grammar = Some(grammar); + self + } + + pub fn structural_tag(mut self, tag: serde_json::Value) -> Self { + self.structural_tag = Some(tag); + self + } + + pub fn build(self) -> Result> { + let mut selected: Option = None; + let mut constraint_count = 0; + + if let Some(choice) = self.choice { + constraint_count += 1; + if constraint_count > 1 { + return Err(anyhow::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + let choice_gram = crate::tools::schema::build_choice_lark_grammar(&choice) + .map_err(|e| anyhow::Error::msg(e))?; + selected = Some(choice_gram); + } + + if let Some(regex) = self.regex { + constraint_count += 1; + if constraint_count > 1 { + return Err(anyhow::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + let regex_gram = TopLevelGrammarExt::from_regex_ascii(®ex); + selected = Some(regex_gram); + } + + if let Some(schema) = self.json { + constraint_count += 1; + if constraint_count > 1 { + return Err(anyhow::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + let schema = crate::tools::schema::sanitize_schema_for_llguidance(&schema); + let json_gram = TopLevelGrammarExt::from_json_schema_utf8(schema) + .map_err(|e| anyhow::Error::msg(e.to_string()))?; + selected = Some(json_gram); + } + + if let Some(grammar) = self.grammar { + constraint_count += 1; + if constraint_count > 1 { + return Err(anyhow::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + let lark_gram = TopLevelGrammarExt::from_lark_utf8(&grammar); + selected = Some(lark_gram); + } + + if let Some(tag) = self.structural_tag { + constraint_count += 1; + if constraint_count > 1 { + return Err(anyhow::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + let (start, end, schema) = crate::tools::schema::parse_structural_tag(&tag) + .map_err(|e| anyhow::Error::msg(e))?; + let schema = crate::tools::schema::sanitize_schema_for_llguidance(&schema); + let tools = crate::tools::schema::schema_to_tools(&schema); + let tool_gram = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag(&start) + .end_tag(&end) + .start_is_special(false) + .end_is_special(false) + .build_json(); + selected = Some(tool_gram); + } + + if selected.is_none() { + return Err(anyhow::Error::msg("structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag")); + } + + Ok(selected) + } +} + +/// Builder for composing multiple grammars with alternation +/// This provides a more readable, declarative way to build composed grammars +pub struct GrammarBuilder { + alternatives: Vec, + max_tokens: Option, +} + +impl GrammarBuilder { + pub fn new() -> Self { + Self { + alternatives: Vec::new(), + max_tokens: None, + } + } + + pub fn alternative(mut self, grammar: TopLevelGrammar) -> Self { + self.alternatives.push(grammar); + self + } + + pub fn max_tokens(mut self, tokens: usize) -> Self { + self.max_tokens = Some(tokens); + self + } + + pub fn build(self) -> TopLevelGrammar { + // Note: GrammarBuilder currently uses chat_text_expression() without EOS tokens + // EOS token support is provided through compose_grammars() directly + match self.alternatives.len() { + 0 => { + let lark = chat_text_expression(); + TopLevelGrammar::from_lark_utf8(&lark) + } + 1 => { + let mut gram = self.alternatives.into_iter().next().unwrap(); + gram.max_tokens = self.max_tokens; + gram + } + _ => { + let merged = merge_top_level_grammars( + self.alternatives, + self.max_tokens, + Some("|".to_string()) + ); + merged + } + } + } +} + +/// Extension trait for TopLevelGrammar with built-in sanitization +/// This ensures all grammar construction paths sanitize inputs consistently +pub trait TopLevelGrammarExt: Sized { + /// Create TopLevelGrammar from regex with ASCII sanitization + fn from_regex_ascii(regex: &str) -> Self; + + /// Create TopLevelGrammar from Lark string with UTF-8 sanitization + fn from_lark_utf8(lark: &str) -> Self; + + /// Create TopLevelGrammar from JSON schema with UTF-8 sanitization + fn from_json_schema_utf8(schema: serde_json::Value) -> Result; +} + +impl TopLevelGrammarExt for TopLevelGrammar { + fn from_regex_ascii(regex: &str) -> Self { + let sanitized = sanitize_to_ascii(regex); + Self::from_regex(&sanitized) + } + + fn from_lark_utf8(lark: &str) -> Self { + let sanitized = sanitize_utf8_valid(lark); + Self::from_lark(sanitized) + } + + fn from_json_schema_utf8(schema: serde_json::Value) -> Result { + let schema_str = serde_json::to_string(&schema)?; + let sanitized = sanitize_utf8_valid(&schema_str); + let val = serde_json::from_str(&sanitized)?; + Ok(Self::from_json_schema(val)) + } +} + +/// Sanitize a string by removing non-ASCII bytes +/// This is used for tool choice strings to ensure only safe ASCII characters reach llguidance lexer +pub fn sanitize_to_ascii(s: &str) -> String { + s.bytes() + .filter(|&b| b.is_ascii()) + .map(|b| b as char) + .collect::() +} + +/// Sanitize a string by removing invalid UTF-8 sequences and control characters +pub fn sanitize_utf8_valid(s: &str) -> String { + let mut result = String::new(); + for ch in s.chars() { + if ch.is_control() && !matches!(ch, '\n' | '\r' | '\t') { + continue; + } + result.push(ch); + } + result +} + +/// Parse a Lark grammar string to extract the start rule RHS and other rules +/// Returns (start_rhs, other_rules) where start_rhs is the RHS of the start: rule +/// The RHS should be a list of rule names separated by | for alternation +fn parse_lark_grammar(lark: &str) -> (String, Vec) { + let lines: Vec<&str> = lark.lines().collect(); + if lines.is_empty() { + return (String::new(), Vec::new()); + } + + let first_line = lines[0].trim(); + if first_line.starts_with("start:") { + // Extract only the rule names after "start:", not the full rule definition + let rhs_part = first_line.strip_prefix("start:").unwrap_or("").trim(); + + // Parse the RHS to get individual rule names (separated by |) + // We only want the rule names, not their definitions + let rule_names: Vec = rhs_part + .split('|') + .map(|s| s.trim().to_string()) + .collect(); + + // The RHS for alternation should be just the rule names + let start_rhs = rule_names.join(" | "); + + // Return all remaining lines as other rules + let other_rules: Vec = lines[1..].iter().map(|s| s.to_string()).collect(); + + (start_rhs, other_rules) + } else { + // No start rule - treat entire grammar as the start rule + (lark.to_string(), Vec::new()) + } +} + +/// Combine grammar rules, handling duplicate rule names by merging them +fn combine_rules(rules: Vec) -> String { + if rules.is_empty() { + return String::new(); + } + + // Group rules by their name (the part before ":") + use std::collections::HashMap; + let mut rule_groups: HashMap> = HashMap::new(); + + for rule in rules { + let rule = rule.trim(); + if rule.is_empty() { + continue; + } + + // Find the rule name (before the first ":") + if let Some(colon_pos) = rule.find(':') { + let name = rule[..colon_pos].trim().to_string(); + let body = rule[colon_pos + 1..].trim().to_string(); + + rule_groups.entry(name).or_default().push(body); + } else { + // Rule without colon - add as-is + rule_groups.entry("anonymous".to_string()).or_default().push(rule.to_string()); + } + } + + // Reconstruct rules, merging duplicates + let mut combined = Vec::new(); + for (name, bodies) in rule_groups { + if bodies.len() == 1 { + combined.push(format!("{}: {}", name, bodies[0])); + } else { + // Multiple definitions for same rule - combine with alternation + combined.push(format!("{}: {}", name, bodies.join(" | "))); + } + } + + combined.join("\n") +} + +/// Merge multiple TopLevelGrammar objects into one +/// This creates a single Lark grammar with alternation at the start rule level +/// Each sub-grammar's rules are combined directly without rule_N indirection +pub fn merge_top_level_grammars(grammars: Vec, max_tokens: Option, start_separator: Option) -> TopLevelGrammar { + // Extract all Lark grammar strings + let mut lark_parts = Vec::new(); + + let sep = match start_separator { + Some(s) => s, + None => "|".to_string(), + }; + + for (_i, g) in grammars.iter().enumerate() { + for gw in &g.grammars { + if let Some(lark) = &gw.lark_grammar { + lark_parts.push(lark.clone()); + } + } + } + + if lark_parts.is_empty() { + let lark_start_exp = format!("start: text\ntext[stop=\"\"]: /((?s).*?)/"); + let mut tlg = TopLevelGrammar::from_lark(lark_start_exp); + tlg.max_tokens = max_tokens; + return tlg; + } + + // Parse each grammar and extract start RHS + other rules + let mut combined_start_rhs = Vec::new(); + let mut all_other_rules = Vec::new(); + + for lark in lark_parts.iter() { + crate::log_debug!("[llg] parse_lark_grammar() input: {}", &lark); + let (start_rhs, other_rules) = parse_lark_grammar(lark); + crate::log_debug!( + "[llg] parse_lark_grammar() -> start_rhs='{}', other_rules_count={}", + start_rhs, other_rules.len() + ); + combined_start_rhs.push(start_rhs); + all_other_rules.extend(other_rules); + } + + // Combine all other rules, handling duplicates + let combined_rules = combine_rules(all_other_rules); + + // Build new grammar with direct alternation at start + let start_separator = format!(" {} ", &sep); + let start_alternation = combined_start_rhs.join(&start_separator); + let final_grammar = format!("start: ( {} )+\n{}", start_alternation, combined_rules); + + let mut top_gram = TopLevelGrammar::from_lark(final_grammar); + top_gram.max_tokens = max_tokens; + top_gram +} + +/// Extract the Lark grammar string from TopLevelGrammar for debugging +pub fn get_lark_from_top_level_grammar(gram: &TopLevelGrammar) -> String { + if gram.grammars.is_empty() { + return "No grammars".to_string(); + } + let larks: Vec = gram.grammars.iter() + .filter_map(|g| g.lark_grammar.as_ref()) + .map(|s| s.clone()) + .collect(); + if larks.is_empty() { + format!("{} grammars, none have lark_grammar", gram.grammars.len()) + } else { + larks.join("\n---\n") + } +} + +/// Lark grammar TEXT pattern for common UTF-8 printable characters +/// Excludes control characters (0x00-0x1F), DEL (0x7F), and C1 controls (0x80-0x9F) +/// This pattern allows: +/// - ASCII printable: space (0x20) through tilde (0x7E) +/// - Unicode text: 0x80 onwards (Latin extended, accented chars, CJK, emoji, etc.) +/// - Common whitespace: newline, carriage return, tab +/// +/// ## Binary Token Matching with llguidance Matcher +/// +/// When working with Qwen-style tool tokens (e.g., `<‌tool_call>`), llguidance uses +/// a **byte-level lexer approach** with the following key concepts: +/// +/// ### 1. Token-Based, Not Byte-Based +/// The `Matcher.compute_mask()` returns a [`SimpleVob`](toktrie::SimpleVob) - a bit vector +/// where each bit represents whether a **token ID** is allowed. This is pre-computed +/// against the tokenizer's trie. +/// +/// ### 2. Special Token Marker (0xFF) +/// llguidance uses byte `0xFF` (TokTrie::SPECIAL_TOKEN_MARKER) to prefix special tokens +/// like `<|end_of_text|>`, `<|eot_id|>`, etc. This is because: +/// - `0xFF` is not valid UTF-8, so it never appears in regular text +/// - In Rust: `&[u8]` can contain 0xFF, but `&str` cannot +/// - Tokenizers like Qwen may embed special tokens as bytes like `[\xFF, b'[', b'1', b'2', b']']` +/// +/// ### 3. Qwen Tool Call Format Example +/// For models like Qwen3 that use `<‌tool_call>` delimiters: +/// +/// ```lark +/// start: tool* +/// tool: "<‌tool_call>" "\n" func "\n" "<‌/tool_call>" ("\n")* +/// func: %json {"type":"object","properties":{"name":...}} +/// ``` +/// +/// ### 4. Current Implementation in vLLM.rs +/// The [`src/core/runner.rs`](src/core/runner.rs) uses logits-based sampling: +/// ```ignore +/// // Apply mask: set disallowed tokens to -inf +/// for tok in 0..vocab_size { +/// if !mask.is_allowed(tok as u32) { +/// row[tok] = f32::NEG_INFINITY; +/// } +/// } +/// ``` +/// This is compatible with llguidance's token-level SimpleVob mask because: +/// - `mask.is_allowed(tok)` checks if token ID `tok` is in the allowed set +/// - The logits are modified to give -inf to disallowed tokens +/// - Sampling then only picks from allowed tokens +/// Sanitize string for Lark grammar - only allow ASCII characters +fn lark_quote(value: &str) -> String { + // Strip non-ASCII characters to prevent grammar parser errors + let sanitized: String = value + .chars() + .filter(|c| c.is_ascii()) + .collect(); + let escaped = sanitized.replace('\\', "\\\\").replace('"', "\\\""); + format!("\"{}\"", escaped) +} + +/// Build special token syntax for Lark grammar using token IDs +/// When token IDs are available, uses <[token_id]> syntax instead of string literals +/// This ensures alignment with the outbound parser's token-based detection +pub fn build_special_token_tag(token_ids: &std::collections::HashSet, fallback: &str) -> String { + if token_ids.is_empty() { + // Fall back to string representation when token IDs are not available + return lark_quote(fallback); + } + // Convert token IDs to Lark special token syntax <[id]> + // The format is: <[token_id]> which matches what the tokenizer expects + let ids: Vec = token_ids.iter().map(|id| format!("[{}]", id)).collect(); + format!("<{}>", ids.join(",")) +} + +/// Build tool call start tag using token IDs when available +pub fn build_tool_call_tag(start_token_ids: &std::collections::HashSet, start_token_str: &str) -> String { + build_special_token_tag(start_token_ids, start_token_str) +} + +/// Build tool call end tag using token IDs when available +pub fn build_tool_call_end_tag(end_token_ids: &std::collections::HashSet, end_token_str: &str) -> String { + build_special_token_tag(end_token_ids, end_token_str) +} + +/// Build TEXT pattern with explicit EOS token IDs using <[id]> syntax +/// The EOS tokens are alternated as optional termination: TEXT eos? +pub fn chat_text_expression_with_eos(special_tokens: &SpecialTokens) -> String { + let eos_token_ids = special_tokens.eos_ids(); + + // First check environment variable override + if let Ok(val) = std::env::var("VLLM_LLG_DEFAULT_TEXT") { + return format!("{}", val); + } + + // Build EOS alternation pattern using <[id]> syntax for token IDs + // LHS must be lowercase - literal tokens aren't allowed in TERMINAL rules + let eos_pattern = if eos_token_ids.is_empty() { + // Fallback to stop="" when no EOS tokens available + r#"start: text +text[stop=""]: /((?s).*?)/"#.to_string() + } else if eos_token_ids.len() == 1 { + format!(r#"start: text_with_eos +text_with_eos: TEXT eos? +TEXT: /(?s:.*)/ +eos: <[{}]>"#, eos_token_ids[0]) + } else { + let ids: Vec = eos_token_ids.iter().map(|id| format!("<[{}]>", id)).collect(); + let eos_alternation = ids.join(" | "); + format!(r#"start: text_with_eos +text_with_eos: TEXT eos? +TEXT: /(?s:.*)/ +eos: {}"#, eos_alternation) + }; + + eos_pattern +} + +/// Build TEXT pattern with stop="" attribute for proper EOS bounding +/// The stop="" attribute sets ends_at_eos: true so the parser can terminate at EOS +/// The [lazy] syntax is for rules, not terminals - options go AFTER the rule name, BEFORE the colon +pub fn chat_text_expression() -> String { + // First check environment variable override + if let Ok(val) = std::env::var("VLLM_LLG_DEFAULT_TEXT") { + return format!("{}", val); + } + + // Use a rule (lowercase) with stop="" attribute for proper EOS termination + // The stop="" tells llguidance to allow EOS token as a valid termination point + // Options go after the rule name, before the colon: text[stop=""]: /pattern/ + r#"start: text +text[stop=""]: /((?s).*?)/"#.to_string() +} + +/// Build grammar vec based on constraint and tool presence +/// Returns a Vec where the first element gets the start: rule +pub fn build_grammar_vec( + constraint_grammars: Vec, + tool_grammar: Option, + tool_choice_required: bool, +) -> Vec { + match (constraint_grammars.is_empty(), tool_grammar.is_some(), tool_choice_required) { + // No constraints, no tools → text only + (true, false, _) => { + let lark_exp = format!("start: text\ntext[stop=\"\"]: /((?s).*?)/"); + vec![TopLevelGrammar::from_lark(lark_exp)] + }, + + // No constraints, tools optional → TEXT | tool_call + (true, true, false) => { + let mut grammars = constraint_grammars; + grammars.push(tool_grammar.unwrap()); + grammars + } + + // No constraints, tools required → tool_call only + (true, true, true) => { + vec![tool_grammar.unwrap()] + } + + // Constraints present, no tools → constraint only + (false, false, _) => constraint_grammars, + + // Constraints present, tools optional → constraint | tool_call + (false, true, false) => { + let mut grammars = constraint_grammars; + grammars.push(tool_grammar.unwrap()); + grammars + } + + // Constraints present, tools required → constraint | tool_call + (false, true, true) => { + let mut grammars = constraint_grammars; + grammars.push(tool_grammar.unwrap()); + grammars + } + } +} + +/// Compose grammars based on constraint and tool settings +/// Returns a single TopLevelGrammar with proper precedence +/// This function takes the grammar that was built externally (with appropriate model-specific format) +/// and handles the alternation/composition logic +pub fn compose_grammars( + mut constraint_grammars: Vec, + tool_grammar: Option, + has_tools: bool, + tool_choice_required: bool, + forced_tool_name: Option, + max_tokens: Option, + special_tokens: &SpecialTokens, +) -> TopLevelGrammar { + crate::log_debug!("[llg] compose_grammars() called: constraints={:?}", constraint_grammars.len()); + crate::log_debug!("[llg] compose_grammars(): has_tools={}, tool_choice_required={}, forced_tool_name={:?}", has_tools, tool_choice_required, forced_tool_name); + + match ( + constraint_grammars.is_empty(), + tool_grammar.is_some(), + tool_choice_required, + forced_tool_name.is_some(), + ) { + // No constraint, no tools → text with EOS bounding + (true, false, _, _) => { + // Build TEXT pattern with explicit EOS token IDs + // This generates: start: text_with_eos, text_with_eos: TEXT eos?, TEXT: /pattern/, eos: <[id]> + let lark = chat_text_expression_with_eos(special_tokens); + crate::log_debug!("[llg] compose_grammars() -> text with EOS: {}", &lark); + TopLevelGrammar::from_lark_utf8(&lark) + } + + // No constraint, tools optional → tool_call | text with EOS + (true, true, false, false) => { + // Build text grammar with explicit EOS token IDs + let lark = chat_text_expression_with_eos(special_tokens); + let text_gram = TopLevelGrammar::from_lark(lark); + let tool_gram = tool_grammar.unwrap(); + let start_sep = "|".to_string(); + let merged = merge_top_level_grammars(vec![text_gram, tool_gram], max_tokens, Some(start_sep)); + crate::log_debug!("[llg] compose_grammars() -> ( text with EOS | tool_call )+"); + merged + } + + // No constraint, tools required → tool_call only + (true, true, true, _) => { + let tool_gram = tool_grammar.unwrap(); + crate::log_debug!("[llg] compose_grammars() -> tool_call only (tools required)"); + tool_gram + } + + // No constraint, tools optional, specific tool forced → tool_call only + (true, true, false, true) => { + let tool_gram = tool_grammar.unwrap(); + crate::log_debug!("[llg] compose_grammars() -> tool_call only (forced tool: {})", forced_tool_name.unwrap()); + tool_gram + } + + // Constraint only, no tools → constraint only + (false, false, _, _) => { + let constraint_gram = constraint_grammars.remove(0); + crate::log_debug!("[llg] compose_grammars() -> constraint only"); + constraint_gram + } + + // Constraint only, tools optional → tool_call | constraint + (false, true, false, false) => { + // Build combined grammar with constraint and tool_call + let constraint_gram = constraint_grammars.remove(0); + let tool_gram = tool_grammar.unwrap(); + // Build the merged grammar with constraint | tool_call + // Use merge_top_level_grammars with None separator (|) + merge_top_level_grammars(vec![constraint_gram, tool_gram], max_tokens, None) + } + + // Constraint only, tools required → tool_call | constraint + (false, true, true, _) => { + let constraint_gram = constraint_grammars.remove(0); + let tool_gram = tool_grammar.unwrap(); + merge_top_level_grammars(vec![constraint_gram, tool_gram], max_tokens, None) + } + + // Constraint only, specific tool forced → tool_call | constraint + (false, true, false, true) => { + let constraint_gram = constraint_grammars.remove(0); + let tool_gram = tool_grammar.unwrap(); + merge_top_level_grammars(vec![constraint_gram, tool_gram], max_tokens, None) + } + } +} + +pub type ParserFactory = LlgParserFactory; + +pub fn build_llg_factory( + tokenizer: Tokenizer, + vocab_size: Option, +) -> Result> { + let tokenizer_vocab = tokenizer.get_vocab_size(true); + let target_vocab = vocab_size.map(|v| { + if v < tokenizer_vocab { + crate::log_warn!( + "Requested vocab size {} is smaller than tokenizer vocab size {}. Using tokenizer size.", + v, + tokenizer_vocab + ); + tokenizer_vocab + } else { + v + } + }); + let env = ByteTokenizer::from_tokenizer(tokenizer)?.into_tok_env(target_vocab)?; + let factory = ParserFactory::new_simple(&env)?; + Ok(Arc::new(factory)) +} + +pub fn load_toktrie_from_path(path: impl AsRef) -> Result { + let tokenizer = ByteTokenizer::from_file(path)?; + let env = ByteTokenizerEnv::new(tokenizer, None)?; + Ok(env.tok_trie) +} + +/// WS regex pattern for Lark grammars - matches whitespace including spaces, tabs, newlines, carriage returns +pub fn lark_ws_regex() -> &'static str { + "/[ \\\\t\\\\r\\\\n]+/" +} + +/// Build Lark grammar string for tool calls +pub fn build_tool_call_lark(tools: &[Tool], schema_map: &std::sync::Arc>, start: &str, end: &str) -> String { + let mut obj_rules = String::new(); + for tool in tools { + let name = &tool.function.name; + let schema_str = serde_json::to_string(schema_map.get(name).unwrap_or(&json!({}))).unwrap_or_default(); + obj_rules.push_str(&format!("obj_{}: %json {}\n", name.replace("-", "_"), schema_str)); + } + + format!("{start} _WS? json_array _WS? {end}\njson_array: \"[\" obj (\",\" obj)* \"]\"\nobj:\n_WS: {}\n{}", lark_ws_regex(), obj_rules.trim_end()) +} + +/// Cache for precomputed mask slices to avoid expensive re-computation +#[derive(Clone, Default)] +pub struct SlicerCache { + cache: HashMap>, +} + +impl SlicerCache { + /// Get or compute a mask slice for a given position + pub fn get_or_compute(&mut self, pos: usize, compute_fn: impl FnOnce() -> Vec) -> &Vec { + if !self.cache.contains_key(&pos) { + self.cache.insert(pos, compute_fn()); + } + self.cache.get(&pos).expect("entry must exist after compute") + } + + /// Clear the cache + pub fn clear(&mut self) { + self.cache.clear(); + } +} pub struct GuidanceState { - // Placeholder for future implementation - _phantom: std::marker::PhantomData<()>, + matcher: Matcher, + /// Track llm tokens for speculative decoding recovery + llm_tokens: Vec, + /// Track llm bytes for rollback calculations + llm_bytes: usize, + /// Cache for precomputed mask slices + slicer_cache: SlicerCache, } impl GuidanceState { - pub fn new(_toktrie: Arc, _schema: Value) -> anyhow::Result { - // Stubbed out - guided decoding temporarily disabled - anyhow::bail!("Guided decoding is temporarily disabled due to llguidance API changes. \ - The TopLevelGrammar::from_json_schema method is no longer available in llguidance >= 0.6") + pub fn new_from_grammar(factory: Arc, grammar: &TopLevelGrammar) -> Result { + crate::log_debug!("[llg] GuidanceState::new_from_grammar() called"); + crate::log_trace!("[llg] Creating parser from grammar"); + let parser = factory.create_parser(grammar.clone())?; + crate::log_trace!("[llg] Creating Matcher from parser"); + let matcher = Matcher::new(Ok(parser)); + crate::log_info!("[llg] GuidanceState created successfully for grammar"); + + Ok(Self { + matcher, + llm_tokens: Vec::new(), + llm_bytes: 0, + slicer_cache: SlicerCache::default(), + }) + } + + /// Compute mask with caching for performance + pub fn compute_mask(&mut self) -> Result> { + crate::log_trace!("[llg] compute_mask() called"); + + if self.matcher.is_stopped() { + crate::log_trace!("[llg] compute_mask() - matcher stopped, returning None"); + return Ok(None); + } + let mask = self.matcher.compute_mask()?; + crate::log_trace!("[llg] compute_mask() - mask computed with {} valid tokens", mask.len()); + Ok(Some(mask)) + } + + /// Commit token and track for speculative decoding recovery + pub fn commit_token(&mut self, token: u32) -> Result<()> { + crate::log_trace!("[llg] commit_token(token={})", token); + + if !self.matcher.is_stopped() { + self.matcher.consume_token(token)?; + crate::log_trace!("[llg] Token {} consumed successfully", token); + self.llm_tokens.push(token); + self.llm_bytes += 4; + } else { + crate::log_trace!("[llg] commit_token() - matcher stopped, skipping"); + } + Ok(()) + } + + /// Get the number of committed tokens + pub fn num_tokens(&self) -> usize { + self.llm_tokens.len() + } + + /// Get the number of committed bytes + pub fn num_bytes(&self) -> usize { + self.llm_bytes + } + + /// Check if guidance is finished + pub fn is_finished(&self) -> bool { + self.matcher.is_stopped() + } + + /// Get the last committed token + pub fn last_token(&self) -> Option { + self.llm_tokens.last().copied() + } + + /// Validate token without consuming it (for re-sampling) + pub fn validate_token(&mut self, token: u32) -> bool { + if self.matcher.is_stopped() { + return true; + } + let result = self.matcher.validate_tokens(&[token]).unwrap_or(0); + let is_valid = result == 1; + if !is_valid { + crate::log_debug!("[llg] Token {} rejected by grammar", token); + } + is_valid + } + + /// Compute mask or return EOS token set if stopped + pub fn compute_mask_or_eos(&mut self) -> Result { + self.matcher.compute_mask_or_eos().map_err(Into::into) + } + + /// Fast-forward tokens without consuming them (for speculative decoding) + pub fn compute_ff_tokens(&mut self) -> Vec { + if self.matcher.is_stopped() { + return Vec::new(); + } + self.matcher.compute_ff_tokens() + } + + /// Fast-forward and consume tokens guaranteed to be accepted by the grammar + pub fn consume_ff_tokens(&mut self) -> Result, anyhow::Error> { + crate::log_debug!("[llg] consume_ff_tokens() called"); + + if self.matcher.is_stopped() { + crate::log_trace!("[llg] consume_ff_tokens() - matcher stopped, returning empty"); + return Ok(Vec::new()); + } + + let ff_tokens = self.matcher.compute_ff_tokens(); + crate::log_debug!("[llg] compute_ff_tokens() returned {} tokens", ff_tokens.len()); + + for &token in &ff_tokens { + crate::log_trace!("[llg] Consuming FF token {}", token); + self.matcher.consume_token(token)?; + self.llm_tokens.push(token); + self.llm_bytes += 4; + } + + crate::log_debug!("[llg] consume_ff_tokens() - successfully consumed {} tokens", ff_tokens.len()); + Ok(ff_tokens) + } + + /// Check if there are pending lexeme bytes to be consumed + pub fn has_pending_lexeme_bytes(&self) -> bool { + false + } + + /// Rollback to a previous state with byte tracking + pub fn rollback_to(&mut self, token_pos: usize, byte_pos: usize) -> Result<()> { + let tokens_to_rollback = self.llm_tokens.len().saturating_sub(token_pos); + if tokens_to_rollback > 0 { + self.matcher.rollback(tokens_to_rollback)?; + } + self.llm_tokens.truncate(token_pos); + self.llm_bytes = byte_pos; + Ok(()) + } + + /// Capture current state as rollback snapshot + pub fn capture_snapshot(&mut self) { } - pub fn compute_allowed_tokens(&mut self) -> anyhow::Result { - anyhow::bail!("Guided decoding is temporarily disabled") + /// Clear all state + pub fn clear(&mut self) { + self.llm_tokens.clear(); + self.llm_bytes = 0; + self.slicer_cache.clear(); } - pub fn commit_token(&mut self, _token: u32) -> anyhow::Result<()> { - anyhow::bail!("Guided decoding is temporarily disabled") + /// Get a reference to the slicer cache + pub fn slicer_cache(&mut self) -> &mut SlicerCache { + &mut self.slicer_cache + } + + /// Validate a sequence of tokens against the grammar + pub fn validate_tokens(&mut self, tokens: &[u32]) -> Option { + if self.matcher.is_stopped() { + return Some(tokens.len()); + } + match self.matcher.validate_tokens(tokens) { + Ok(count) => Some(count), + Err(_) => None, + } } } -pub struct AllowedTokens { - pub tokens: Vec, - pub is_stopped: bool, +/// Apply sparse mask bias to logits +/// Uses iter_set_entries to only iterate allowed tokens +pub fn _batch_mask_bias( + logits: &Tensor, + masks: &[(usize, SimpleVob)], + vocab_size: usize, +) -> candle_core::Result { + let batch_size = masks.len(); + + // Create bias vector initialized to -inf + let mut bias_data = vec![f32::NEG_INFINITY; batch_size * vocab_size]; + + // Fill in allowed tokens using sparse iteration + // masks is Vec<(batch_idx, SimpleVob)> where batch_idx is the sequence position in the batch + for (batch_idx, mask) in masks.iter() { + mask.iter_set_entries(|idx| { + if idx < vocab_size { + bias_data[*batch_idx * vocab_size + idx] = 0.0; + } + }); + } + + // Create bias tensor on same device as logits + let bias_tensor = Tensor::from_vec(bias_data, (batch_size, vocab_size), logits.device())?; + + // GPU tensor addition (no CPU copy) + logits.broadcast_add(&bias_tensor) } -pub fn build_toktrie_from_tokenizer_bytes(bytes: &[u8]) -> anyhow::Result { - // Try to build TokTrie from bytes - // The new API uses TokTrie::from() with TokRxInfo and words - // For now, return an error as the exact migration path needs investigation - anyhow::bail!("TokTrie construction from tokenizer bytes is temporarily disabled. \ - The TokTrie::from_huggingface_bytes method is no longer available in toktrie >= 1.0. \ - Input bytes length: {}", bytes.len()) +/// Two-stage validation with early exit +/// Stage 1: Sample and validate token +/// Stage 2: Only compute mask if token is invalid +pub fn _early_exit_validate( + guidance_states: &mut HashMap, + seq_ids: &[usize], + tokens: &mut [u32], + logits: &Tensor, + vocab_size: usize, + _factory: &Arc, + sampling: &Sampling, + logit_processor: &LogitsProcessor, +) -> candle_core::Result<()> { + for (seq_idx, seq_id) in seq_ids.iter().enumerate() { + let token = tokens[seq_idx]; + + if let Some(state) = guidance_states.get_mut(seq_id) { + // Stage 1: Validate token + if state.validate_token(token) { + // Early exit - token is valid, consume it + state.commit_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?; + continue; + } + + crate::log_debug!("[llg] Token {} is invalid, computing mask for seq {}", token, seq_id); + + // Stage 2: Token is invalid, compute mask and re-sample + let mask = match state.compute_mask_or_eos() { + Ok(m) => m, + Err(e) => { + crate::log_error!("[llg] Unable to compute mask for token {} due to {}", token, e); + continue; + } + }; + + crate::log_debug!("[llg] Applying bias to logits for seq {}", seq_id); + + // Build bias vector using sparse iteration + let mut acc = vec![f32::NEG_INFINITY; vocab_size]; + mask.iter_set_entries(|idx| { + if idx < acc.len() { + acc[idx] = 0.0; + } + }); + + // Get current sequence's logits as 1D tensor - MUST CLONE to avoid cross-contamination + let row_start = seq_idx * vocab_size; + let row_end = row_start + vocab_size; + let logits_vec = logits.flatten_all()?.to_vec1::()?; + let mut row_vec = logits_vec.clone(); // Clone to avoid modifying original + let row = &mut row_vec[row_start..row_end]; + + // Apply bias directly to this sequence's row + for tok in 0..vocab_size { + if acc[tok] != 0.0 { + row[tok] = f32::NEG_INFINITY; + } + } + + // Create 1D tensor for just this sequence + let biased_row = Tensor::from_vec(row_vec[row_start..row_end].to_vec(), (vocab_size,), logits.device())?; + + // Re-sample just this sequence from the biased 1D logits + let re_sampled = logit_processor.sample_with_strategy(&biased_row, sampling)?; + tokens[seq_idx] = re_sampled[0]; // 1D output, first (only) element + + crate::log_debug!("[llg] Consuming re-sampled token {} for seq {}", tokens[seq_idx], seq_id); + + // Commit the re-sampled token + state.commit_token(tokens[seq_idx]).map_err(|e| candle_core::Error::Msg(e.to_string()))?; + } else { + crate::log_debug!("[llg] No guidance state for seq {}", seq_id); + } + } + + Ok(()) } -pub fn load_toktrie_from_path(_: &Path) -> Option { - // Temporarily disabled - returns None - // crate::log_warn!("load_toktrie_from_path is disabled: {:?}", path); - None +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sanitize_to_ascii() { + let input = "hello"; + let sanitized = sanitize_to_ascii(input); + assert_eq!(sanitized, "hello"); + } + + #[test] + fn test_sanitize_utf8_valid() { + let input = "hello\x00\x01world"; + let sanitized = sanitize_utf8_valid(input); + assert_eq!(sanitized, "helloworld"); + } + + #[test] + fn test_grammar_builder_single_alternative() { + let grammar = GrammarBuilder::new() + .alternative(TopLevelGrammar::from_lark("start: 'a'".to_string())) + .build(); + assert!(grammar.grammars.len() > 0); + } + + #[test] + fn test_grammar_builder_multiple_alternatives() { + let grammar = GrammarBuilder::new() + .alternative(TopLevelGrammar::from_lark("start: 'a'".to_string())) + .alternative(TopLevelGrammar::from_lark("start: 'b'".to_string())) + .build(); + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains("start: ( 'a' | 'b' )+"), "Expected direct alternation"); + } + + #[test] + fn test_grammar_builder_with_max_tokens() { + let grammar = GrammarBuilder::new() + .alternative(TopLevelGrammar::from_lark("start: 'test'".to_string())) + .max_tokens(100) + .build(); + assert_eq!(grammar.max_tokens, Some(100)); + } + + #[test] + fn test_grammar_builder_default_text() { + let grammar = GrammarBuilder::new().build(); + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains("start: text"), "Expected default text pattern"); + } + + #[test] + fn test_merge_top_level_grammars_direct_alternation() { + // Test that merge_top_level_grammars produces direct alternation without rule_N indirection + let gram1 = TopLevelGrammar::from_lark("start: 'a'".to_string()); + let gram2 = TopLevelGrammar::from_lark("start: 'b'".to_string()); + // Use None for default separator (|) + let result = merge_top_level_grammars(vec![gram1, gram2], None, None); + + // Get the combined Lark string + let lark_str = get_lark_from_top_level_grammar(&result); + + // Verify that start: directly alternates 'a' | 'b' without rule_N indirection + assert!(lark_str.contains("start: ( 'a' | 'b' )+"), "Expected direct alternation in start rule: {}", lark_str); + // Verify that rule_N indirection is NOT present + assert!(!lark_str.contains("rule_0:"), "Should not contain rule_0 indirection"); + assert!(!lark_str.contains("rule_1:"), "Should not contain rule_1 indirection"); + } + + #[test] + fn test_merge_top_level_grammars_with_text_and_tool() { + // Test the actual TEXT | tool_call scenario from the issue + let lark = format!("start: TEXT\n{}", chat_text_expression()); + let text_gram = TopLevelGrammar::from_lark(lark); + let tool_gram = TopLevelGrammar::from_lark("start: tool_call\ntool_call: \"test\"".to_string()); + // Use None for default separator (|) + let result = merge_top_level_grammars(vec![text_gram, tool_gram], None, None); + + // Get the combined Lark string + let lark_str = get_lark_from_top_level_grammar(&result); + + // Verify that start: directly alternates TEXT | tool_call + assert!(lark_str.contains("start: ( TEXT | tool_call )+"), "Expected direct alternation: {}", lark_str); + // Verify that rule_N indirection is NOT present + assert!(!lark_str.contains("rule_0:"), "Should not contain rule_0 indirection"); + assert!(!lark_str.contains("rule_1:"), "Should not contain rule_1 indirection"); + } + + #[test] + fn test_merge_top_level_grammars_with_grammar_without_start() { + // Verify that when merging a grammar without start: line, it gets properly handled + let gram1 = TopLevelGrammar::from_lark("start: 'a'\n'a': 'a'".to_string()); + let gram2 = TopLevelGrammar::from_lark("'tool': 'call'\ntool: %json {\"type\":\"object\"}".to_string()); + // Use None for default separator (|) + let result = merge_top_level_grammars(vec![gram1, gram2], None, None); + + // Get the combined Lark string + let lark_str = get_lark_from_top_level_grammar(&result); + + // Should still have direct alternation at start + assert!(lark_str.contains("start:"), "Expected start rule in merged grammar"); + // The tool grammar should be properly included + assert!(lark_str.contains("'tool': 'call'"), "Expected tool content in merged grammar"); + } +} + +#[cfg(test)] +mod tool_grammar_builder_tests { + use super::*; + use crate::tools::ToolBuilder; + use std::collections::HashSet; + + #[test] + fn test_tool_grammar_builder_json_single_tool() { + let tools = vec![ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query".to_string(), "string".to_string(), "Search query".to_string(), true) + .build()]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(false) + .end_is_special(false) + .build_json(); + assert!(grammar.grammars.len() > 0); + } + + #[test] + fn test_tool_grammar_builder_json_multiple_tools() { + let tools = vec![ + ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query".to_string(), "string".to_string(), "Search query".to_string(), true) + .build(), + ToolBuilder::new("weather".to_string(), "Get weather".to_string()) + .param("city".to_string(), "string".to_string(), "City name".to_string(), true) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(false) + .end_is_special(false) + .build_json(); + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains("obj_search:"), "Should contain obj_search rule"); + assert!(lark_str.contains("obj_weather:"), "Should contain obj_weather rule"); + } + + #[test] + fn test_tool_grammar_builder_xml_single_tool() { + let tools = vec![ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query".to_string(), "string".to_string(), "Search query".to_string(), true) + .build()]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(false) + .end_is_special(false) + .build_xml(); + assert!(grammar.grammars.len() > 0); + } + + #[test] + fn test_tool_grammar_builder_xml_multiple_tools() { + let tools = vec![ + ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query".to_string(), "string".to_string(), "Search query".to_string(), true) + .build(), + ToolBuilder::new("weather".to_string(), "Get weather".to_string()) + .param("city".to_string(), "string".to_string(), "City name".to_string(), true) + .build(), + ]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(false) + .end_is_special(false) + .build_xml(); + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains("tool_content: tool_0 | tool_1"), "Expected tool alternation"); + } + + #[test] + fn test_tool_grammar_builder_with_token_ids() { + let tools = vec![ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query".to_string(), "string".to_string(), "Search query".to_string(), true) + .build()]; + let mut start_ids = HashSet::new(); + start_ids.insert(151657); + let mut end_ids = HashSet::new(); + end_ids.insert(151658); + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(false) + .end_is_special(false) + .start_token_ids(Some(start_ids)) + .end_token_ids(Some(end_ids)) + .build_json(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains("<[151657]>"), "Should contain start token ID"); + assert!(lark_str.contains("<[151658]>"), "Should contain end token ID"); + } + + #[test] + fn test_tool_grammar_builder_special_tags() { + let tools = vec![ToolBuilder::new("search".to_string(), "Search the web".to_string()) + .param("query".to_string(), "string".to_string(), "Search query".to_string(), true) + .build()]; + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(true) + .end_is_special(true) + .build_json(); + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains(""), "Should contain special start tag"); + assert!(lark_str.contains(""), "Should contain special end tag"); + } + + #[test] + fn test_tool_grammar_builder_empty_tools_json() { + let grammar = ToolGrammarBuilder::new() + .tools(&[]) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(false) + .end_is_special(false) + .build_json(); + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("obj: %json"), "Should have obj rule with generic schema"); + } + + #[test] + fn test_tool_grammar_builder_empty_tools_xml() { + let grammar = ToolGrammarBuilder::new() + .tools(&[]) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(false) + .end_is_special(false) + .build_xml(); + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains("start: tool_call"), "Should have start: tool_call"); + assert!(lark_str.contains("tool_content:"), "Should have tool_content rule"); + } + + #[test] + fn test_tool_grammar_builder_complex_schema() { + let tools = vec![ToolBuilder::new("edit_file".to_string(), "Edit a file".to_string()) + .param("file_path".to_string(), "string".to_string(), "Path to the file".to_string(), true) + .param("old_string".to_string(), "string".to_string(), "String to replace".to_string(), true) + .param("new_string".to_string(), "string".to_string(), "Replacement string".to_string(), true) + .param("max_replacements".to_string(), "integer".to_string(), "Maximum replacements".to_string(), false) + .build()]; + + let grammar = ToolGrammarBuilder::new() + .tools(&tools) + .start_tag("".to_string()) + .end_tag("".to_string()) + .start_is_special(false) + .end_is_special(false) + .build_xml(); + + let lark_str = get_lark_from_top_level_grammar(&grammar); + assert!(lark_str.contains("param_0_0:"), "Should have param_0_0 rule (file_path - required)"); + assert!(lark_str.contains("param_0_1:"), "Should have param_0_1 rule (old_string - required)"); + assert!(lark_str.contains("param_0_2:"), "Should have param_0_2 rule (new_string - required)"); + assert!(lark_str.contains("param_0_3:"), "Should have param_0_3 rule (max_replacements - optional)"); + } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 124b63ad..e3f27065 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -14,6 +14,7 @@ pub mod image; pub mod kvcache_allocator; pub mod logits_processor; pub mod progress; +pub mod special_tokens; use crate::core::GenerationOutput; use crate::models::gemma3::config::Gemma3Config; use crate::utils::config::MoEConfig; @@ -24,7 +25,7 @@ use crate::utils::downloader::ModelPaths; use crate::utils::gguf_helper::{get_gguf_info, GGUFInfo}; use candle_core::utils::{cuda_is_available, metal_is_available}; use candle_core::{DType, Device, Result}; -use config::{Config, EngineConfig, EosTokenId, GenerationConfig, TokenizerConfig}; +use config::{Config, EngineConfig, GenerationConfig, TokenizerConfig}; use std::collections::HashMap; use std::path::{Path, PathBuf}; use tokenizers::Tokenizer; @@ -174,10 +175,10 @@ pub fn config_from_gguf( let eos_token_id = md_get("tokenizer.ggml.eos_token_id"); - let eos_token_id = if eos_token_id.is_ok() { - EosTokenId::Single(eos_token_id.unwrap().to_u32()?) + let _eos_token_id = if eos_token_id.is_ok() { + Some(vec![eos_token_id.unwrap().to_u32()?]) } else { - EosTokenId::Multiple(vec![]) + None }; // ---------------- RoPE scaling -------------------------- @@ -375,7 +376,7 @@ pub fn config_from_gguf( final_logit_softcapping: None, tie_word_embeddings: Some(!has_output_weight), bos_token_id, - eos_token_id: Some(eos_token_id), + eos_token_id: None, use_sliding_window: None, sliding_window: None, max_window_layers: None, @@ -437,8 +438,10 @@ fn merge_multimodal_top_level_config( if let Some(eos) = raw_root.get("eos_token_id") { if !eos.is_null() { - if let Ok(eos_token_id) = serde_json::from_value::(eos.clone()) { - config.eos_token_id = Some(eos_token_id); + if let Ok(eos_ids) = serde_json::from_value::>(eos.clone()) { + config.eos_token_id = Some(eos_ids); + } else if let Ok(eos_id) = serde_json::from_value::(eos.clone()) { + config.eos_token_id = Some(vec![eos_id]); } } } @@ -656,6 +659,7 @@ pub fn init_config_tokenizer( .map_err(candle_core::Error::wrap)?; let mut config: Config = serde_json::from_value(config_value) .map_err(candle_core::Error::wrap)?; + // Gemma3Config already uses Vec for eos_token_id config.eos_token_id = gemma3_cfg.eos_token_id; config } diff --git a/src/utils/special_tokens.rs b/src/utils/special_tokens.rs new file mode 100644 index 00000000..699715c9 --- /dev/null +++ b/src/utils/special_tokens.rs @@ -0,0 +1,548 @@ +use std::collections::HashSet; +use image::EncodableLayout; +use tokenizers::tokenizer::Tokenizer; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] +pub enum Category { + Eos, + Pad, + Bos, + Sep, + Cls, + Mask, + Tool, + Function, + Parameter, + Role, + ContentType, + Reasoning, + Other, +} + + +impl Category { + pub fn search_strings(&self) -> Vec { + match self { + Self::Eos => vec![ + "" , "" , "<,eos,>", "<,end_of_text,>" , "<,end,>" , + "<,eot,>" , "<,eot_id,>" , "<,eom_id,>" , "<,end_of_turn,>" , + "<,endoftext,>" , "<,endofsequence,>" , "[EOS]", "<|im_end|>", + "<|box_end|>", "<|object_ref_end|>","<|quad_end|>", "<|endoftext|>", + "<|vision_end|>", "<|eot|>", "<|python_end|>", "<|end_of_text|>", + "<|header_end|>", "<|eom|>" + ], + Self::Bos => vec![ + "", "" , "<,bos,>", "<,bos_token,>" , + "<,begin_of_text,>" , "<,startoftext,>" , "<,start,>" , + "<,im_start,>" , "[BOS]", "<|box_start|>", "<|im_start|>", + "<|object_ref_start|>", "<|quad_start|>", "<|vision_start|>", + "<|python_start|>", "<|begin_of_text|>", "<|header_start|>" + ], + Self::Pad => vec![ + "" , "<,pad,>" , "" , "<,pad_token,>" , + "[PAD]" , "", "<|image_pad|>", "<|video_pad|>", + "<|vision_pad|>", + ], + Self::Sep => vec![ + "" , "<,sep,>" , "<,separator,>" , "[SEP]" + ], + Self::Cls => vec![ + "" , "<,cls,>" , "[CLS]" , "" + ], + Self::Mask => vec![ + "" , "<,mask,>" , "[MASK]" , "" , + "<,mask_token,>" , "<,infill_mask,>" , "<,extra_id_0,>" , + "" , "" + ], + Self::Role => vec![ + "<,system,>" , "<,user,>" , "<,assistant,>" , "<,role,>" , + "<,critic,>" , "<,observer,>" , "" , + "" , "" , "" + ], + Self::ContentType => vec![ + "<,content_type,>" , "<,content,>" , "<,text,>" , "<,code,>" , + "<,json,>" , "<,markdown,>" , "<,output,>" , "<,html,>" , + "<,data,>" , "<,datatype,>", "<|image|>" + ], + Self::Tool=> vec![ + "<|python_tag|>", "<|eom_id|>", + "", "", + "", "", + "[TOOL_CALLS]", "]", + "", "", + + ], + Self::Function => vec![ + "<,function,>" , "" , "<,functions,>" , "<,fn,>" , "" , + "<,tool,>" , "" , "<,tools,>" , "<,api,>" , + "<,invoke,>" , "<,function_call,>" , "<,tool_call,>" , + "<,function_call_json,>" + ], + Self::Parameter => vec![ + "" , "<,parameter,>" , "<,parameters,>" , + "<,args,>" , "<,arguments,>" , "" , + "" , "<,params,>" + ], + Self::Reasoning => vec![ + " magnesium " , " magnesia " , "" , "" , + "<,thinking,>" , "<,reasoning,>" , "" , "<,reason,>" , + "<,thought,>" , "<,thoughts,>" , "<,internal,>" , + "" , "" , "<,reflect,>" , + "" , "<,chain_of_thought,>" , "<,analysis,>" , + "<,rationale,>" , "<,explanation,>", "", "" + ], + Self::Other => vec![ + "<,eos_token,>" , "<,unk_token,>" , "" , "[UNK]" , + "<,start_header_id,>" , "<,end_header_id,>" , + "<,metadata,>" , "<,special,>" + ], + }.iter().map(|e| e.to_string()).collect() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VocabSource { + Special, + Added, + Common, +} + +#[derive(Debug, Clone)] +pub struct SpecialToken { + pub category: Category, + pub id: u32, + pub content: Vec, + pub source: VocabSource, + pub normalized: bool, +} + +impl SpecialToken { + pub fn string(&self) -> String { + self.content.clone() + .into_iter() + .filter(|b| b.is_ascii()) + .map(|b| b as char) + .collect() + } +} + + +#[derive(Debug, Clone, Default)] +pub struct SpecialTokens { + token_set: Vec, +} + +// Private macros for internal implementation +macro_rules! filter_by_category { + ($self:ident, $cat:ident) => { + $self.token_set.iter().filter(|t| t.category == Category::$cat).cloned().collect::>() + }; +} + +macro_rules! filter_by_category_source { + ($self:ident, $cat:ident, $src:ident) => { + $self.token_set.iter() + .filter(|t| t.category == Category::$cat && t.source == VocabSource::$src) + .cloned() + .collect::>() + }; +} + +impl SpecialTokens { + // Public category accessors + pub fn eos(&self) -> Vec { filter_by_category!(self, Eos) } + pub fn pad(&self) -> Vec { filter_by_category!(self, Pad) } + pub fn bos(&self) -> Vec { filter_by_category!(self, Bos) } + pub fn sep(&self) -> Vec { filter_by_category!(self, Sep) } + pub fn cls(&self) -> Vec { filter_by_category!(self, Cls) } + pub fn mask(&self) -> Vec { filter_by_category!(self, Mask) } + pub fn tool(&self) -> Vec { filter_by_category!(self, Tool) } + pub fn function(&self) -> Vec { filter_by_category!(self, Function) } + pub fn parameter(&self) -> Vec { filter_by_category!(self, Parameter) } + pub fn role(&self) -> Vec { filter_by_category!(self, Role) } + pub fn content_type(&self) -> Vec { filter_by_category!(self, ContentType) } + pub fn reasoning(&self) -> Vec { filter_by_category!(self, Reasoning) } + pub fn other(&self) -> Vec { filter_by_category!(self, Other) } + + // Public ID accessors returning Vec + pub fn eos_ids(&self) -> Vec { self.eos().iter().map(|t| t.id).collect() } + pub fn pad_ids(&self) -> Vec { self.pad().iter().map(|t| t.id).collect() } + pub fn bos_ids(&self) -> Vec { self.bos().iter().map(|t| t.id).collect() } + pub fn sep_ids(&self) -> Vec { self.sep().iter().map(|t| t.id).collect() } + pub fn cls_ids(&self) -> Vec { self.cls().iter().map(|t| t.id).collect() } + pub fn mask_ids(&self) -> Vec { self.mask().iter().map(|t| t.id).collect() } + pub fn tool_ids(&self) -> Vec { self.tool().iter().map(|t| t.id).collect() } + pub fn function_ids(&self) -> Vec { self.function().iter().map(|t| t.id).collect() } + pub fn parameter_ids(&self) -> Vec { self.parameter().iter().map(|t| t.id).collect() } + pub fn role_ids(&self) -> Vec { self.role().iter().map(|t| t.id).collect() } + pub fn content_type_ids(&self) -> Vec { self.content_type().iter().map(|t| t.id).collect() } + pub fn reasoning_ids(&self) -> Vec { self.reasoning().iter().map(|t| t.id).collect() } + pub fn other_ids(&self) -> Vec { self.other().iter().map(|t| t.id).collect() } + + // Public ID accessors returning HashSet for O(1) lookup + pub fn eos_ids_set(&self) -> HashSet { self.eos_ids().into_iter().collect() } + pub fn pad_ids_set(&self) -> HashSet { self.pad_ids().into_iter().collect() } + pub fn bos_ids_set(&self) -> HashSet { self.bos_ids().into_iter().collect() } + pub fn sep_ids_set(&self) -> HashSet { self.sep_ids().into_iter().collect() } + pub fn cls_ids_set(&self) -> HashSet { self.cls_ids().into_iter().collect() } + pub fn mask_ids_set(&self) -> HashSet { self.mask_ids().into_iter().collect() } + pub fn tool_ids_set(&self) -> HashSet { self.tool_ids().into_iter().collect() } + pub fn function_ids_set(&self) -> HashSet { self.function_ids().into_iter().collect() } + pub fn parameter_ids_set(&self) -> HashSet { self.parameter_ids().into_iter().collect() } + pub fn role_ids_set(&self) -> HashSet { self.role_ids().into_iter().collect() } + pub fn content_type_ids_set(&self) -> HashSet { self.content_type_ids().into_iter().collect() } + pub fn reasoning_ids_set(&self) -> HashSet { self.reasoning_ids().into_iter().collect() } + pub fn other_ids_set(&self) -> HashSet { self.other_ids().into_iter().collect() } + + /// Get all token IDs across all categories as HashSet for O(1) lookup + pub fn all_ids_set(&self) -> HashSet { self.token_set.iter().map(|t| t.id).collect() } + + // Public string accessors + pub fn eos_strings(&self) -> Vec { self.eos().iter().map(|t| t.string()).collect() } + pub fn pad_strings(&self) -> Vec { self.pad().iter().map(|t| t.string()).collect() } + pub fn bos_strings(&self) -> Vec { self.bos().iter().map(|t| t.string()).collect() } + pub fn sep_strings(&self) -> Vec { self.sep().iter().map(|t| t.string()).collect() } + pub fn cls_strings(&self) -> Vec { self.cls().iter().map(|t| t.string()).collect() } + pub fn mask_strings(&self) -> Vec { self.mask().iter().map(|t| t.string()).collect() } + pub fn tool_strings(&self) -> Vec { self.tool().iter().map(|t| t.string()).collect() } + pub fn function_strings(&self) -> Vec { self.function().iter().map(|t| t.string()).collect() } + pub fn parameter_strings(&self) -> Vec { self.parameter().iter().map(|t| t.string()).collect() } + pub fn role_strings(&self) -> Vec { self.role().iter().map(|t| t.string()).collect() } + pub fn content_type_strings(&self) -> Vec { self.content_type().iter().map(|t| t.string()).collect() } + pub fn reasoning_strings(&self) -> Vec { self.reasoning().iter().map(|t| t.string()).collect() } + pub fn other_strings(&self) -> Vec { self.other().iter().map(|t| t.string()).collect() } + + /// Search for tokens by ID, substring, category, and source. + /// All parameters are optional - use None to skip that filter. + pub fn search( + &self, + id: Option, + substring: Option<&str>, + category: Option, + source: Option, + ) -> Vec { + let mut results = Vec::new(); + + for token in &self.token_set { + // Filter by ID if specified + if let Some(target_id) = id { + if token.id != target_id { + continue; + } + } + + // Filter by substring if specified + if let Some(sub) = substring { + let token_str = token.string(); + if !token_str.contains(sub) { + continue; + } + } + + // Filter by category if specified + if let Some(cat) = category { + if token.category != cat { + continue; + } + } + + // Filter by source if specified + if let Some(src) = source { + if token.source != src { + continue; + } + } + + results.push(token.clone()); + } + + results + } + + /// Create SpecialTokens from a tokenizer + pub fn new(tokenizer: &Tokenizer) -> Self { + let mut token_set: Vec = Vec::new(); + let mut seen_ids: HashSet = HashSet::new(); + + // Step 1: Process all tokens from tokenizer (added + base vocab) + // First, get added tokens + for (id, added_token) in tokenizer.get_added_tokens_decoder() { + if seen_ids.contains(&id) { + continue; + } + seen_ids.insert(id); + + // Determine category from content + let category = Self::categorize_by_content(&added_token.content); + + // Determine source + let source = if added_token.special { + VocabSource::Special + } else { + VocabSource::Added + }; + + token_set.push(SpecialToken { + category, + id, + content: added_token.content.as_bytes().to_vec(), + source, + normalized: added_token.normalized + }); + } + + // Step 2: Add tokens from base vocabulary that match known patterns + let vocab = tokenizer.get_vocab(true); + for (token_str, id) in vocab { + // Find potential duplicates of our special tokens in the common vocab + if seen_ids.contains(&id) || !token_set.iter().any( + |f| String::from_utf8(f.content.clone()).unwrap() == token_str.to_string() + ) { + continue; + } + + // Try to categorize by content + let category = Self::categorize_by_content(token_str.as_str()); + if category != Category::Other { + token_set.push(SpecialToken { + category, + id, + content: token_str.as_bytes().to_vec(), + source: VocabSource::Common, + normalized: false, + }); + } + } + + // Sort by id for consistent ordering + token_set.sort_by_key(|t| t.id); + + Self { token_set } + } + fn categorize_by_content(content: &str) -> Category { + for cat in &[Category::Eos, Category::Pad, Category::Bos, Category::Sep, + Category::Cls, Category::Mask, Category::Tool, Category::Function, + Category::Parameter, Category::Role, Category::ContentType, Category::Reasoning] { + if cat.search_strings().iter().any(|s| s == content) { + return *cat; + } + } + Category::Other + } + + /// Create SpecialTokens from a tokenizer file path + pub fn new_from_file(tokenizer_path: &str) -> Self { + let tokenizer = Tokenizer::from_file(tokenizer_path).expect("Failed to load tokenizer"); + Self::new(&tokenizer) + } + + /// Get tool start token IDs (tokens categorized as Tool category that are start markers) + /// Start markers are those that don't start with Vec { + self.tool() + .iter() + .filter(|t| { + let s = t.string(); + !s.starts_with(" Vec { + self.tool() + .iter() + .filter(|t| { + let s = t.string(); + s.starts_with(" HashSet { + self.tool_start_ids().into_iter().collect() + } + + /// Get tool end token IDs as HashSet for O(1) lookup + pub fn tool_end_ids_set(&self) -> HashSet { + self.tool_end_ids().into_iter().collect() + } + + /// Get tool start token SpecialToken if available + /// Returns the SpecialToken object containing both ID and string representation + pub fn tool_start_token(&self) -> Option { + self.tool().iter().find(|t| { + let s = t.string(); + !s.starts_with(" Option { + self.tool().iter().find(|t| { + let s = t.string(); + s.starts_with(" Option<(SpecialToken, SpecialToken)> { + let start = self.tool_start_token()?; + let end = self.tool_end_token()?; + Some((start, end)) + } + + /// Get all tokens + pub fn all_tokens(&self) -> Vec { + self.token_set.clone() + } + + /// Get all special tokens (Special or Added source) + pub fn all_special(&self) -> Vec { + self.token_set.iter() + .filter(|t| t.source == VocabSource::Special || t.source == VocabSource::Added) + .cloned() + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_special_token_string_conversion() { + let token = SpecialToken { + category: Category::Eos, + id: 2, + content: b"".to_vec(), + source: VocabSource::Added, + normalized: false, + }; + assert_eq!(token.string(), ""); + } + + #[test] + fn test_categorize_special_tokens() { + let tokenizer = Tokenizer::from_file("tests/fixtures/tokenizer.json").ok(); + + if let Some(tok) = tokenizer { + let special_tokens = SpecialTokens::new(&tok); + + // Check that we have some tokens stored + assert!(!special_tokens.eos().is_empty() || + !special_tokens.pad().is_empty()); + } + } + + #[test] + fn test_token_uniqueness() { + let tokenizer = Tokenizer::from_file("tests/fixtures/tokenizer.json").ok(); + + if let Some(tok) = tokenizer { + let special_tokens = SpecialTokens::new(&tok); + + // Check that no category has duplicate IDs + let all_ids: Vec = special_tokens.eos_ids() + .into_iter() + .chain(special_tokens.pad_ids()) + .chain(special_tokens.bos_ids()) + .collect(); + + let unique_ids: HashSet = all_ids.iter().cloned().collect(); + assert_eq!(all_ids.len(), unique_ids.len(), "Duplicate token IDs found"); + } + } + + #[test] + fn test_search_by_id() { + let tokenizer = Tokenizer::from_file("tests/fixtures/tokenizer.json").ok(); + + if let Some(tok) = tokenizer { + let special_tokens = SpecialTokens::new(&tok); + + // Search for a specific ID + let results = special_tokens.search(Some(2), None, None, None); + + // Each result should have the matching ID + for result in &results { + assert_eq!(result.id, 2); + } + } + } + + #[test] + fn test_search_by_content() { + let tokenizer = Tokenizer::from_file("tests/fixtures/tokenizer.json").ok(); + + if let Some(tok) = tokenizer { + let special_tokens = SpecialTokens::new(&tok); + + // Search for tokens containing "end" + let results = special_tokens.search(None, Some("end"), None, None); + + // Each result should contain the search string + for result in &results { + assert!(result.string().contains("end")); + } + } + } + + #[test] + fn test_search_by_category() { + let tokenizer = Tokenizer::from_file("tests/fixtures/tokenizer.json").ok(); + + if let Some(tok) = tokenizer { + let special_tokens = SpecialTokens::new(&tok); + + // Search for EOS tokens + let results = special_tokens.search(None, None, Some(Category::Eos), None); + + // Each result should be an EOS token + for result in &results { + assert_eq!(result.category, Category::Eos); + } + } + } + + #[test] + fn test_search_by_source() { + let tokenizer = Tokenizer::from_file("tests/fixtures/tokenizer.json").ok(); + + if let Some(tok) = tokenizer { + let special_tokens = SpecialTokens::new(&tok); + + // Search for tokens with Added source + let results = special_tokens.search(None, None, None, Some(VocabSource::Added)); + + // Each result should have Added source + for result in &results { + assert_eq!(result.source, VocabSource::Added); + } + } + } + + #[test] + fn test_combined_search() { + let tokenizer = Tokenizer::from_file("tests/fixtures/tokenizer.json").ok(); + + if let Some(tok) = tokenizer { + let special_tokens = SpecialTokens::new(&tok); + + // Search for EOS tokens with specific substring + let results = special_tokens.search(None, Some("end"), Some(Category::Eos), None); + + // Each result should match all criteria + for result in &results { + assert_eq!(result.category, Category::Eos); + assert!(result.string().contains("end")); + } + } + } +}