diff --git a/Cargo.toml b/Cargo.toml index b1cfdf4f..58db529b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vllm-rs" -version = "0.8.10" +version = "0.8.11" edition = "2021" default-run = "vllm-rs" @@ -15,7 +15,8 @@ itertools = "0.13.0" akin = "0.4.0" indicatif = "0.17.11" serde_json = "1.0.108" -llguidance = "0.6" +llguidance = { version = "1.2.0", default-features = false, features = ["lark"] } +toktrie_hf_tokenizers = "1.2.0" toktrie = "1.4" half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] } tokio = { version = "1.38.0", features = ["sync"] } diff --git a/docs/claude_code.md b/docs/claude_code.md index 3cb68d12..83117b0a 100644 --- a/docs/claude_code.md +++ b/docs/claude_code.md @@ -18,6 +18,22 @@ python3 -m vllm_rs.server --m miromind-ai/MiroThinker-v1.5-30B --d 0,1 --server ## 2) Configure Claude Code +Install Claude Code + +```shell +npm install -g @anthropic-ai/claude-code +``` + +Export config + +```shell +export ANTHROPIC_BASE_URL="http://127.0.0.1:8000" +export ANTHROPIC_AUTH_TOKEN="sk-dummy" +export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 +``` + +Or make it permanent + Set `~/.claude/settings.json` (or copy from `example/claude/settings.json`): ```json @@ -32,7 +48,15 @@ Set `~/.claude/settings.json` (or copy from `example/claude/settings.json`): } ``` -## 3) Verify with a direct request (optional) +## 3) Run Claude Code + +Run Claude Code + +```shell +claude +``` + +Or verify with a direct request (optional) ```bash curl http://127.0.0.1:8000/v1/messages \ diff --git a/docs/goose.md b/docs/goose.md index 57985b66..14ad1435 100644 --- a/docs/goose.md +++ b/docs/goose.md @@ -17,19 +17,17 @@ python3 -m vllm_rs.server --m Qwen/Qwen3-30B-A3B-Instruct-2507 --d 0,1 --server ## 2) Configure Goose -### Download and install Goose: https://block.github.io/goose/docs/getting-started/installation/ - ```shell # For non-UI system, export GOOSE_DISABLE_KEYRING=1 ``` - Export empty API KEY ```shell 
export VLLM_API_KEY="empty" ``` +### Download and install Goose: https://block.github.io/goose/docs/getting-started/installation/ ### Configure goose with `Custom Providers` and API key `empty` diff --git a/src/core/engine.rs b/src/core/engine.rs index 8dbaf567..e18e873f 100644 --- a/src/core/engine.rs +++ b/src/core/engine.rs @@ -20,7 +20,7 @@ use crate::transfer::PdRole; use crate::transfer::Transfer; use crate::utils::chat_template::Message; use crate::utils::config::{EngineConfig, EosTokenId, ModelType, SamplingParams}; -use crate::utils::guidance::load_toktrie_from_path; +use crate::utils::guidance::{build_llg_factory, load_toktrie_from_path}; use crate::utils::heartbeat::heartbeat_worker; use crate::utils::image::{get_image_config, ImageData, ImageProcessConfig}; use crate::utils::kvcache_allocator::KVCacheAllocator; @@ -106,9 +106,23 @@ impl LLMEngine { pub fn new(econfig: &EngineConfig, dtype: DType) -> Result>> { let (model_pathes, is_gguf, mut config, config_tokenizer, tokenizer, mut generation_cfg) = init_config_tokenizer(econfig)?; - let toktrie = load_toktrie_from_path(&model_pathes.get_tokenizer_filename()).map(Arc::new); + let toktrie = match load_toktrie_from_path(&model_pathes.get_tokenizer_filename()) { + Ok(trie) => Some(Arc::new(trie)), + Err(e) => { + crate::log_warn!("Failed to load tokenizer trie: {}", e); + None + } + }; + let llg_factory = match build_llg_factory(tokenizer.clone(), config.vocab_size) { + Ok(f) => Some(f), + Err(e) => { + crate::log_warn!("Failed to build llguidance factory: {}", e); + None + } + }; + if toktrie.is_none() { - crate::log_warn!("Guided decoding disabled: tokenizer trie unavailable."); + crate::log_warn!("Guided decoding (legacy) disabled: tokenizer trie unavailable."); } let stop_flag = Arc::new(AtomicBool::new(false)); @@ -196,7 +210,7 @@ impl LLMEngine { device.clone(), reporter, transfer, - toktrie.clone(), + llg_factory.clone(), None, )?; @@ -1507,4 +1521,8 @@ impl LLMEngine { pub fn 
get_chat_template(&self) -> ChatTemplate { self.template.clone() } + + pub fn template_supports_tools(&self) -> bool { + self.template.supports_tools() + } } diff --git a/src/core/runner.rs b/src/core/runner.rs index 99346c01..01445fa4 100644 --- a/src/core/runner.rs +++ b/src/core/runner.rs @@ -6,7 +6,7 @@ use crate::server::EmbeddingStrategy; use crate::transfer::Transfer; #[cfg(all(feature = "cuda", feature = "graph"))] use crate::utils::graph::{CudaGraphFn, CudaGraphWrapper, GraphCapturer, ModelFn}; -use crate::utils::guidance::GuidanceState; +use crate::utils::guidance::{GuidanceState, ParserFactory}; use crate::utils::image::compute_image_slice; use crate::utils::logits_processor::{LogitsProcessor, Sampling}; use crate::utils::progress::ProgressLike; @@ -28,10 +28,9 @@ use attention_rs::InputMetadata; use candle_core::{DType, Device, Result, Tensor, D}; use interprocess::local_socket::Stream as LocalStream; use parking_lot::RwLock; -use std::collections::HashMap; +use std::collections::{hash_map::Entry, HashMap, HashSet}; use std::rc::Rc; use std::sync::{Arc, Mutex, MutexGuard}; -use toktrie::TokTrie; /// Cached sampling parameters computed once during prefill, reused during decode #[derive(Clone, Debug)] @@ -82,6 +81,9 @@ pub struct ModelRunner { cached_sampling: RwLock>, seq_tokens: RwLock>>, guidance_states: RwLock>, + guidance_failed: RwLock>, + guidance_mismatch: RwLock>, + llg_factory: Option>, transfer: Option>, /// Whether this runner is on the first rank (for logging) is_first_rank: bool, @@ -101,7 +103,7 @@ impl ModelRunner { device: Device, reporter: Arc>>, transfer: Option>, - toktrie: Option>, + llg_factory: Option>, stream: Option, ) -> Result { let model = crate::build_model!( @@ -200,6 +202,30 @@ impl ModelRunner { } else { econfig.seed.unwrap() }; + let model_vocab_size = match &model { + Model::Qwen3(model) => model.get_vocab_size(), + Model::Qwen3MoE(model) => model.get_vocab_size(), + Model::LLaMa(model) => model.get_vocab_size(), + 
Model::Phi4(model) => model.get_vocab_size(), + Model::GLM4(model) => model.get_vocab_size(), + Model::GLM4MoE(model) => model.get_vocab_size(), + Model::Mistral3VL(model) => model.get_vocab_size(), + Model::Gemma3(model) => model.get_vocab_size(), + Model::Qwen3VL(model) => model.get_vocab_size(), + }; + + if let Some(factory) = &llg_factory { + let llg_vocab_size = factory.tok_env().tok_trie().vocab_size(); + if llg_vocab_size != model_vocab_size { + crate::log_warn!( + "llguidance vocab size {} does not match model vocab size {} for {:?}.", + llg_vocab_size, + model_vocab_size, + model_type + ); + } + } + Ok(Self { model, gpu_kv_cache: Arc::new(Mutex::new(gpu_kv_cache)), @@ -218,6 +244,9 @@ impl ModelRunner { cached_sampling: RwLock::new(None), seq_tokens: RwLock::new(HashMap::new()), guidance_states: RwLock::new(HashMap::new()), + guidance_failed: RwLock::new(HashSet::new()), + guidance_mismatch: RwLock::new(HashSet::new()), + llg_factory, transfer, is_first_rank: comm.rank() == 0, model_type, @@ -700,6 +729,104 @@ impl ModelRunner { logits.to_owned() }; + let logits = if let Some(factory) = &self.llg_factory { + let mut guidance_states = self.guidance_states.write(); + let mut guidance_failed = self.guidance_failed.write(); + let mut guidance_mismatch = self.guidance_mismatch.write(); + let mut modified = false; + let vocab_size = logits.dim(1)?; + // We only materialize logits on CPU if at least one constraint mask applies. 
+ + // We'll collect masks first to minimize holding locks or complex logic inside the loop + let mut masks = Vec::new(); // (seq_index, seq_id, mask) + + for (i, id) in seq_ids.iter().enumerate() { + let seq_constraint = match &seqs { + Seqs::SeqRefs(refs) => &refs[i].sampling_params.constraint, + Seqs::DecodeVec(vec) => &vec[i].sampling_params.constraint, + }; + + if guidance_failed.contains(id) { + continue; + } + + if let Some(constraint) = seq_constraint { + let state = match guidance_states.entry(*id) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + match GuidanceState::new(factory.clone(), constraint) { + Ok(state) => entry.insert(state), + Err(err) => { + guidance_failed.insert(*id); + crate::log_warn!( + "[Seq {}] Failed to create guidance state: {}. Disabling constraints for this sequence.", + id, + err + ); + continue; + } + } + } + }; + + if let Ok(Some(mask)) = state.compute_mask() { + masks.push((i, *id, mask)); + modified = true; + } + } + } + + if modified { + // Now we must convert to Vec, modify, and update logits + let mut logits_vec = logits.flatten_all()?.to_vec1::()?; + + for (seq_idx, seq_id, mask) in masks { + let start = seq_idx * vocab_size; + let end = start + vocab_size; + let row = &mut logits_vec[start..end]; + let mask_len = mask.len(); + + // Apply mask: set disallowed to -inf + // This iterates entire vocab, but check is fast + if mask_len == 0 { + if guidance_failed.insert(seq_id) { + crate::log_warn!( + "[Seq {}] Guidance mask length is 0. Disabling constraints for this sequence.", + seq_id + ); + } + continue; + } + + if mask_len != vocab_size && guidance_mismatch.insert(seq_id) { + crate::log_warn!( + "[Seq {}] Guidance mask size {} does not match vocab size {}. 
Clamping mask application.", + seq_id, + mask_len, + vocab_size + ); + } + + let apply_len = std::cmp::min(vocab_size, mask_len); + for tok in 0..apply_len { + if !mask.is_allowed(tok as u32) { + row[tok] = f32::NEG_INFINITY; + } + } + if mask_len < vocab_size { + for tok in mask_len..vocab_size { + row[tok] = f32::NEG_INFINITY; + } + } + } + Tensor::from_vec(logits_vec, logits.shape(), &self.device)? + } else { + logits + } + } else { + logits + }; + let tokens = self .logit_processor .sample_with_strategy(&logits, &cached_params.sampling)?; @@ -718,6 +845,18 @@ impl ModelRunner { } } } + + // Commit tokens to guidance states + if let Some(_) = &self.llg_factory { + let mut guidance_states = self.guidance_states.write(); + for (i, id) in seq_ids.iter().enumerate() { + if let Some(state) = guidance_states.get_mut(id) { + if !state.is_finished() { + let _ = state.commit_token(tokens[i]); + } + } + } + } Ok(tokens) } @@ -726,6 +865,10 @@ impl ModelRunner { let _ = seq_tokens.remove(&id); let mut guidance_states = self.guidance_states.write(); let _ = guidance_states.remove(&id); + let mut guidance_failed = self.guidance_failed.write(); + let _ = guidance_failed.remove(&id); + let mut guidance_mismatch = self.guidance_mismatch.write(); + let _ = guidance_mismatch.remove(&id); } pub fn get_model_vocab_size(&self) -> usize { diff --git a/src/core/scheduler.rs b/src/core/scheduler.rs index d9f7e58b..aee3b71b 100644 --- a/src/core/scheduler.rs +++ b/src/core/scheduler.rs @@ -5,11 +5,11 @@ use super::{ prefix_cache::PrefixCacheConfig, sequence::{Sequence, SequenceStatus}, }; +use crate::tools::parser::prefix_could_be_tool; use crate::transfer::{PdConfig, PdRole}; use crate::utils::config::{Config, EngineConfig, EosTokenId}; use candle_core::Result; use parking_lot::RwLock; -use regex::Regex; use std::collections::VecDeque; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; @@ -24,12 +24,10 @@ pub struct Scheduler { eos_token_id: Vec, /// Token IDs that represent 
the end of a tool call (e.g., tokens) tool_call_end_token_ids: Vec, - /// Token ID for } character (used for JSON tool call detection) - json_end_token_id: Option, - /// Tokenizer for decoding output to check JSON tool call patterns + /// Token IDs for JSON end characters (e.g., '}' and ']') + json_end_token_ids: Vec, + /// Tokenizer for decoding output to check JSON tool call completion tokenizer: Option>, - /// Regex for detecting JSON tool calls - tool_call_regex: Regex, cfg: EngineConfig, pd_config: Option, is_last_prefill: bool, @@ -125,11 +123,8 @@ impl Scheduler { }, // Tool call end tokens will be set by engine after tokenizer is initialized tool_call_end_token_ids: Vec::new(), - json_end_token_id: None, + json_end_token_ids: Vec::new(), tokenizer: None, - // Regex to match JSON tool call format: {"name": "...", "arguments": {...}} - // We use (?s) to allow dot matching newlines - tool_call_regex: Regex::new(r#"(?s)\{\s*"name"\s*:.*"arguments"\s*:.*\}\s*$"#).unwrap(), cfg: econfig.clone(), pd_config: econfig.pd_config.clone(), is_last_prefill: false, @@ -143,13 +138,21 @@ impl Scheduler { /// Set tokenizer for JSON tool call detection (called by engine after initialization) pub fn set_tokenizer(&mut self, tokenizer: Arc) { - // Get the token ID for "}" character - if let Ok(tokens) = tokenizer.encode("}", false) { - if let Some(&token_id) = tokens.get_ids().last() { - self.json_end_token_id = Some(token_id); - crate::log_info!("JSON end token ID (}}) set to: {}", token_id); + self.json_end_token_ids.clear(); + + for ch in ["}", "]"] { + if let Ok(tokens) = tokenizer.encode(ch, false) { + if let Some(&token_id) = tokens.get_ids().last() { + if !self.json_end_token_ids.contains(&token_id) { + self.json_end_token_ids.push(token_id); + } + } } } + + if !self.json_end_token_ids.is_empty() { + crate::log_info!("JSON end token IDs set to: {:?}", self.json_end_token_ids); + } self.tokenizer = Some(tokenizer); } @@ -1007,19 +1010,24 @@ impl Scheduler { return true; 
} - // 2. Check for JSON style tool call using Regex - // This handles models like Qwen3 that output raw JSON without XML tags - if self.json_end_token_id == Some(token) { + // 2. Check for JSON style tool call by attempting to parse complete JSON + if self.json_end_token_ids.contains(&token) { if let Some(tokenizer) = &self.tokenizer { // Temporarily add the token to get complete output for decoding let mut temp_output = self.running[idx].output_ids.to_vec(); temp_output.push(token); if let Ok(decoded) = tokenizer.decode(&temp_output, true) { - // Check for JSON tool call pattern using Regex - // The pattern matches if the decoded string ends with a valid JSON tool call structure - if self.tool_call_regex.is_match(&decoded) { - return true; + let trimmed = decoded.trim(); + if let Ok(val) = serde_json::from_str::(trimmed) { + if val.is_object() || val.is_array() { + return true; + } + } else { + let (_could_be, is_complete) = prefix_could_be_tool(trimmed); + if is_complete { + return true; + } } } } diff --git a/src/runner/runner.rs b/src/runner/runner.rs index 366bef37..bae07b4c 100644 --- a/src/runner/runner.rs +++ b/src/runner/runner.rs @@ -7,13 +7,14 @@ use std::io::Write; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use tokenizers::Tokenizer; use vllm_rs::core::runner::{ModelRunner, Seqs}; use vllm_rs::models::layers::distributed::Comm; use vllm_rs::models::layers::VarBuilderX; use vllm_rs::runner::{receive_local, send_local, MessageType}; use vllm_rs::transfer::PdRole; use vllm_rs::transfer::Transfer; -use vllm_rs::utils::guidance::load_toktrie_from_path; +use vllm_rs::utils::guidance::build_llg_factory; use vllm_rs::utils::heartbeat::heartbeat_worker; use vllm_rs::utils::new_device; use vllm_rs::utils::progress::{ProgressLike, ProgressReporter, RemoteProgressReporter}; @@ -134,8 +135,15 @@ fn main() -> anyhow::Result<()> { )?; let stream_kv = Some(stream.try_clone()?); let mut econfig = init_req.econfig.clone(); - let 
toktrie = load_toktrie_from_path(&init_req.model_pathes.get_tokenizer_filename()) - .map(Arc::new); + let tokenizer = Tokenizer::from_file(init_req.model_pathes.get_tokenizer_filename()) + .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?; + let llg_factory = match build_llg_factory(tokenizer, init_req.config.vocab_size) { + Ok(f) => Some(f), + Err(e) => { + vllm_rs::log_warn!("Failed to build llguidance factory: {}", e); + None + } + }; #[allow(unused_mut)] let mut runner = ModelRunner::new( init_req.model_type, @@ -148,7 +156,7 @@ fn main() -> anyhow::Result<()> { device, progress_reporter, transfer, - toktrie, + llg_factory, stream_kv, )?; diff --git a/src/server/claude_server.rs b/src/server/claude_server.rs index 7756c958..1514ed0e 100644 --- a/src/server/claude_server.rs +++ b/src/server/claude_server.rs @@ -5,7 +5,9 @@ use super::{ use crate::core::engine::{LLMEngine, StreamItem}; use crate::server::logger::ChatCompletionLogger; use crate::server::parser::{ParserState, StreamResult, StreamToolParser}; -use crate::tools::helpers::{build_tool_schema_map, filter_tool_calls}; +use crate::tools::helpers::{ + build_tool_schema_map, filter_tool_calls, sanitize_tools_for_llguidance, +}; use crate::tools::parser::ToolParser; use crate::tools::{Tool, ToolCall, ToolChoice, ToolFormat}; use crate::utils::config::SamplingParams; @@ -1083,6 +1085,9 @@ pub async fn messages( } else { mcp_tools.clone() }; + if params.constraint.is_some() { + resolved_tools = sanitize_tools_for_llguidance(&resolved_tools); + } let mut tool_choice_instruction: Option = None; let mut forced_tool_name: Option = None; let mut tool_choice_required = false; @@ -1158,22 +1163,40 @@ pub async fn messages( }; let _tool_choice = tool_choice_to_openai(&request.tool_choice); - let (model_type, tool_config) = { + let (model_type, tool_config, template_supports_tools) = { let e = data.engine.read(); - (e.model_type.clone(), e.tool_config.clone()) + ( + e.model_type.clone(), + 
e.tool_config.clone(), + e.template_supports_tools(), + ) }; if !resolved_tools.is_empty() { - let tool_prompt_template = data.engine.read().econfig.tool_prompt_template.clone(); - let mut tool_prompt = if let Some(template) = tool_prompt_template { - template + let mut tool_prompt: Option = None; + if !template_supports_tools { + let tool_prompt_template = data.engine.read().econfig.tool_prompt_template.clone(); + let mut prompt = if let Some(template) = tool_prompt_template { + template + } else { + ToolFormat::get_tool_prompt(&model_type) + }; + if let Some(instruction) = tool_choice_instruction.as_ref() { + prompt = format!("{prompt}\n\n{instruction}"); + } + tool_prompt = Some(prompt); + } + + let instruction_only = tool_prompt.is_none() && tool_choice_instruction.is_some(); + let system_injection = if instruction_only { + tool_choice_instruction.clone() } else { - ToolFormat::get_tool_prompt(&model_type) + tool_prompt }; - if let Some(instruction) = tool_choice_instruction.as_ref() { - tool_prompt = format!("{tool_prompt}\n\n{instruction}"); + + if let Some(system_text) = system_injection { + inject_tool_prompt(&mut chat_messages, &system_text); } - inject_tool_prompt(&mut chat_messages, &tool_prompt); } let img_cfg = { @@ -1483,6 +1506,43 @@ pub async fn messages( } else { let buffer = tool_parser.take_buffer(); if !buffer.is_empty() { + let (could_be_tool, complete) = + crate::tools::parser::prefix_could_be_tool(&buffer); + if could_be_tool && !complete { + let snippet: String = + buffer.chars().take(120).collect(); + crate::log_warn!( + "[Seq {}] Incomplete tool call at stream end, dropping {} chars: {}", + seq_id, + buffer.len(), + snippet + ); + } else { + let _ = send_text_with_start( + &stream_ctx, + &mut text_block_started, + text_block_index, + &buffer, + ); + } + } + } + } + ParserState::MaybeStart => { + let buffer = tool_parser.take_buffer(); + if !buffer.is_empty() { + let (could_be_tool, complete) = + 
crate::tools::parser::prefix_could_be_tool(&buffer); + if could_be_tool && !complete { + let snippet: String = + buffer.chars().take(120).collect(); + crate::log_warn!( + "[Seq {}] Incomplete tool call prefix at stream end, dropping {} chars: {}", + seq_id, + buffer.len(), + snippet + ); + } else { let _ = send_text_with_start( &stream_ctx, &mut text_block_started, @@ -1492,17 +1552,6 @@ pub async fn messages( } } } - ParserState::MaybeStart => { - let buffer = tool_parser.take_buffer(); - if !buffer.is_empty() { - let _ = send_text_with_start( - &stream_ctx, - &mut text_block_started, - text_block_index, - &buffer, - ); - } - } ParserState::Normal => {} } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 3b8d867e..5810baa7 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -7,9 +7,10 @@ pub mod server; pub mod streaming; use crate::core::engine::LLMEngine; use crate::server::streaming::Streamer; +use crate::tools::schema::{build_tool_call_lark_grammar, sanitize_schema_for_llguidance}; use crate::transfer::PdRole; use crate::utils::chat_template::Message; -use crate::utils::config::EngineConfig; +use crate::utils::config::{Constraint, EngineConfig}; use crate::utils::image::{ compute_tokens_per_image, get_tensor_raw_data, load_image_from_base64, load_image_from_url, ImageData, ImageProcessConfig, ImageProcessTrait, IMAGE_PLACEHOLDER, @@ -24,11 +25,48 @@ use colored::*; use local_ip_address::local_ip; use parking_lot::RwLock; use rustchatui::start_ui_server; -use serde_json::json; +use serde_json::{json, Value}; use std::collections::HashMap; use std::sync::Arc; use tower_http::cors::{Any, CorsLayer}; +#[derive(Debug, Deserialize, Serialize, Clone, Default)] +pub struct StructuredOutputs { + #[serde(default)] + pub choice: Option>, + #[serde(default)] + pub regex: Option, + #[serde(default)] + pub json: Option, + #[serde(default)] + pub grammar: Option, + #[serde(default)] + pub structural_tag: Option, +} + +#[derive(Debug, Deserialize, Serialize, 
Clone, Default)] +pub struct ExtraBody { + #[serde(default)] + pub structured_outputs: Option, + #[serde(flatten)] + pub extra: HashMap, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct ResponseFormatJsonSchema { + #[serde(default)] + pub name: Option, + pub schema: serde_json::Value, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct ResponseFormat { + #[serde(rename = "type")] + pub format_type: String, + #[serde(default)] + pub json_schema: Option, +} + #[derive(Debug, Deserialize, Serialize)] pub struct ChatCompletionRequest { pub messages: Vec, @@ -41,6 +79,8 @@ pub struct ChatCompletionRequest { pub presence_penalty: Option, #[serde(alias = "enable_thinking")] pub thinking: Option, + #[serde(default, alias = "stop_sequences")] + pub stop: Option>, pub stream: Option, pub session_id: Option, /// Tools available for the model to call @@ -49,6 +89,203 @@ pub struct ChatCompletionRequest { /// How the model should choose which tool to call #[serde(default)] pub tool_choice: Option, + /// OpenAI-style response format for structured outputs + #[serde(default)] + pub response_format: Option, + /// Extra body for OpenAI-compatible clients (e.g. 
structured_outputs) + #[serde(default)] + pub extra_body: Option, +} + +fn lark_quote(value: &str) -> String { + let escaped = value.replace('\\', "\\\\").replace('"', "\\\""); + format!("\"{}\"", escaped) +} + +fn build_choice_lark_grammar(choices: &[String]) -> Result { + if choices.is_empty() { + candle_core::bail!( + "{}", + "structured_outputs.choice must include at least one option".to_string() + ); + } + + let mut parts = Vec::with_capacity(choices.len()); + for choice in choices { + if choice.is_empty() { + candle_core::bail!( + "{}", + "structured_outputs.choice cannot contain empty strings".to_string() + ); + } + parts.push(lark_quote(choice)); + } + + Ok(format!("start: {}\n", parts.join(" | "))) +} + +fn normalize_tag_pair(tag: &str) -> Result<(String, String)> { + let trimmed = tag.trim(); + if trimmed.is_empty() { + candle_core::bail!( + "{}", + "structured_outputs.structural_tag.tag cannot be empty".to_string() + ); + } + + if trimmed.starts_with('<') && trimmed.ends_with('>') { + let inner = trimmed + .trim_start_matches('<') + .trim_end_matches('>') + .trim_start_matches('/'); + if inner.is_empty() { + candle_core::bail!( + "{}", + "structured_outputs.structural_tag.tag is invalid".to_string() + ); + } + let start = if trimmed.starts_with("", inner) + } else { + trimmed.to_string() + }; + let end = format!("", inner); + Ok((start, end)) + } else { + Ok((format!("<{}>", trimmed), format!("", trimmed))) + } +} + +fn parse_structural_tag(value: &Value) -> Result<(String, String, Value)> { + let obj = value.as_object().ok_or_else(|| { + candle_core::Error::Msg("structured_outputs.structural_tag must be an object".to_string()) + })?; + + let schema = obj.get("schema").cloned().ok_or_else(|| { + candle_core::Error::Msg("structured_outputs.structural_tag.schema is required".to_string()) + })?; + + let start = obj + .get("start_tag") + .or_else(|| obj.get("start")) + .or_else(|| obj.get("tag")); + let end = obj.get("end_tag").or_else(|| obj.get("end")); + 
+ let (start_tag, end_tag) = match (start, end) { + (Some(start), Some(end)) => { + let start = start.as_str().ok_or_else(|| { + candle_core::Error::Msg( + "structured_outputs.structural_tag.start_tag must be a string".to_string(), + ) + })?; + let end = end.as_str().ok_or_else(|| { + candle_core::Error::Msg( + "structured_outputs.structural_tag.end_tag must be a string".to_string(), + ) + })?; + (start.to_string(), end.to_string()) + } + (Some(tag), None) if obj.contains_key("tag") => { + let tag = tag.as_str().ok_or_else(|| { + candle_core::Error::Msg( + "structured_outputs.structural_tag.tag must be a string".to_string(), + ) + })?; + normalize_tag_pair(tag)? + } + (Some(_), None) => { + candle_core::bail!( + "{}", + "structured_outputs.structural_tag.end_tag is required when start_tag is provided" + .to_string(), + ); + } + _ => { + candle_core::bail!( + "{}", + "structured_outputs.structural_tag requires tag or start_tag/end_tag".to_string(), + ); + } + }; + + Ok((start_tag, end_tag, schema)) +} + +fn sanitize_json_schema(schema: &Value, context: &str) -> Result { + if !schema.is_object() { + candle_core::bail!("{} must be a JSON Schema object", context); + } + Ok(sanitize_schema_for_llguidance(schema)) +} + +pub(crate) fn constraint_from_structured_outputs( + structured: &StructuredOutputs, +) -> Result> { + let mut selected: Option = None; + + let mut set_constraint = |constraint: Constraint| -> Result<()> { + if selected.is_some() { + candle_core::bail!("{}", + "structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag" + .to_string(), + ); + } + selected = Some(constraint); + Ok(()) + }; + + if let Some(choice) = &structured.choice { + let grammar = build_choice_lark_grammar(choice)?; + set_constraint(Constraint::Lark(grammar))?; + } + + if let Some(regex) = &structured.regex { + set_constraint(Constraint::Regex(regex.clone()))?; + } + + if let Some(schema) = &structured.json { + let schema = sanitize_json_schema(schema, 
"structured_outputs.json")?; + set_constraint(Constraint::JsonSchema(schema))?; + } + + if let Some(grammar) = &structured.grammar { + set_constraint(Constraint::Lark(grammar.clone()))?; + } + + if let Some(tag) = &structured.structural_tag { + let (start, end, schema) = parse_structural_tag(tag)?; + let schema = sanitize_json_schema(&schema, "structured_outputs.structural_tag.schema")?; + let grammar = build_tool_call_lark_grammar(&schema, &start, &end, false, false); + set_constraint(Constraint::Lark(grammar))?; + } + + if selected.is_none() { + candle_core::bail!("{}", + "structured_outputs must set exactly one of choice, regex, json, grammar, or structural_tag" + .to_string(), + ); + } + + Ok(selected) +} + +pub(crate) fn constraint_from_response_format( + response_format: &ResponseFormat, +) -> Result> { + match response_format.format_type.as_str() { + "json_schema" => { + let Some(schema) = response_format.json_schema.as_ref() else { + candle_core::bail!("{}", "response_format.json_schema is required".to_string()); + }; + let schema = + sanitize_json_schema(&schema.schema, "response_format.json_schema.schema")?; + Ok(Some(Constraint::JsonSchema(schema))) + } + other => candle_core::bail!( + "Unsupported response_format type '{}'; only 'json_schema' is supported", + other + ), + } } #[derive(Deserialize)] diff --git a/src/server/parser.rs b/src/server/parser.rs index 7f530126..077304a6 100644 --- a/src/server/parser.rs +++ b/src/server/parser.rs @@ -3,9 +3,9 @@ //! Handles model-specific tool call tokens and formats. 
use crate::server::{ChatChoiceChunk, ChatCompletionChunk, Delta}; -use crate::tools::{FunctionCall, ToolCall}; +use crate::tools::parser::{parse_tool_calls_from_text, prefix_could_be_tool}; +use crate::tools::ToolCall; use crate::utils::config::ModelType; -use serde_json::Value; use std::collections::HashSet; use tokenizers::Tokenizer; @@ -40,6 +40,8 @@ pub struct ToolConfig { pub end_token_ids: HashSet, pub start_token_str: String, pub end_token_str: String, + pub start_is_special: bool, + pub end_is_special: bool, } impl ToolConfig { @@ -58,6 +60,8 @@ impl ToolConfig { end_token_ids: end_ids, start_token_str: "<|python_tag|>".to_string(), end_token_str: "<|eom_id|>".to_string(), + start_is_special: false, + end_is_special: false, } } ModelType::Qwen3 | ModelType::Qwen3MoE | ModelType::Qwen3VL => { @@ -69,6 +73,8 @@ impl ToolConfig { end_token_ids: end_ids, start_token_str: "".to_string(), end_token_str: "".to_string(), + start_is_special: false, + end_is_special: false, } } ModelType::Mistral | ModelType::Mistral3VL => { @@ -79,6 +85,8 @@ impl ToolConfig { end_token_ids: end_ids, start_token_str: "[TOOL_CALLS]".to_string(), end_token_str: "]".to_string(), + start_is_special: false, + end_is_special: false, } } ModelType::Gemma | ModelType::Gemma3 => { @@ -88,6 +96,8 @@ impl ToolConfig { end_token_ids: end_ids, start_token_str: "".to_string(), end_token_str: "".to_string(), + start_is_special: false, + end_is_special: false, } } // Phi, GLM, Yi, StableLM, DeepSeek - use Qwen format (text-only) @@ -102,6 +112,8 @@ impl ToolConfig { end_token_ids: HashSet::new(), start_token_str: "".to_string(), end_token_str: "".to_string(), + start_is_special: false, + end_is_special: false, }, } } @@ -142,6 +154,9 @@ impl ToolConfig { ); self.end_token_ids.clear(); } + + self.start_is_special = Self::is_special_token(tokenizer, &self.start_token_str); + self.end_is_special = Self::is_special_token(tokenizer, &self.end_token_str); } /// Resolve tool call end token IDs using 
tokenizer and the validated config. @@ -192,15 +207,39 @@ impl ToolConfig { Err(_) => false, } } + + fn is_special_token(tokenizer: &Tokenizer, text: &str) -> bool { + if text.is_empty() { + return false; + } + let encoded = match tokenizer.encode(text, false) { + Ok(encoded) => encoded, + Err(_) => return false, + }; + let ids = encoded.get_ids(); + if ids.len() != 1 { + return false; + } + let id = ids[0]; + let added = tokenizer.get_added_tokens_decoder(); + if let Some(info) = added.get(&id) { + if info.content == text + && (info.special || (info.content.starts_with('<') && info.content.ends_with('>'))) + { + return true; + } + } + false + } } /// Streaming tool parser that handles tool call detection and buffering pub struct StreamToolParser { + #[allow(dead_code)] config: ToolConfig, state: ParserState, buffer: String, model_id: String, - parse_strategy: String, // Accumulated output for final parsing accumulated_output: String, // Reasoning block tracking @@ -227,19 +266,12 @@ impl StreamToolParser { } /// Create a new parser with a pre-validated tool config - pub fn new_with_config(model_type: &ModelType, model_id: String, config: ToolConfig) -> Self { - let parse_strategy = match model_type { - ModelType::Mistral | ModelType::Mistral3VL => "mistral_list", - _ => "json", - } - .to_string(); - + pub fn new_with_config(_model_type: &ModelType, model_id: String, config: ToolConfig) -> Self { Self { config, state: ParserState::Normal, buffer: String::new(), model_id, - parse_strategy, accumulated_output: String::new(), active_reasoning_end: None, in_code_block: false, @@ -274,7 +306,7 @@ impl StreamToolParser { /// Process a single incoming token. /// Returns StreamResult indicating what action to take. 
- pub fn process_token(&mut self, token_id: u32, token_text: &str) -> StreamResult { + pub fn process_token(&mut self, _token_id: u32, token_text: &str) -> StreamResult { // Always accumulate self.accumulated_output.push_str(token_text); @@ -305,114 +337,42 @@ impl StreamToolParser { // Don't detect tool calls inside reasoning or code blocks if self.in_reasoning() || self.in_code_block { + if !self.buffer.is_empty() { + let flushed = self.take_buffer(); + return StreamResult::FlushBuffer(format!("{}{}", flushed, token_text)); + } return StreamResult::Content(token_text.to_string()); } - match self.state.clone() { - ParserState::Normal => { - // Check for start trigger - if self.is_start_token(token_id, token_text) { - self.state = ParserState::Buffering; - self.buffer.clear(); + self.buffer.push_str(token_text); - if let Some(pos) = token_text.find(&self.config.start_token_str) { - let before = &token_text[..pos]; - let after = &token_text[pos + self.config.start_token_str.len()..]; - if !after.is_empty() { - self.buffer.push_str(after); - } - if !before.is_empty() { - return StreamResult::Content(before.to_string()); - } - } - - crate::log_info!( - "Tool call {} ({}) found, start buffering!", - token_text, - token_id + let (could_be_tool, tool_complete) = prefix_could_be_tool(&self.buffer); + if could_be_tool || tool_complete { + self.state = ParserState::Buffering; + if tool_complete { + let mut tool_calls = + parse_tool_calls_from_text(&self.buffer, &mut self.tool_call_index); + let result = if tool_calls.is_empty() { + crate::log_error!( + "Unable to parse tool call buffer: {}\n of accumulated buffer: {}", + self.buffer, + self.accumulated_output ); - return StreamResult::Buffering; - } - - // Check for partial tag match at end of current token - if !self.config.has_start_tokens() { - if let Some((prefix, partial)) = self.split_partial_start(token_text) { - self.state = ParserState::MaybeStart; - self.buffer.clear(); - self.buffer.push_str(&partial); - return 
if prefix.is_empty() { - StreamResult::Buffering - } else { - StreamResult::Content(prefix) - }; - } - } - - // Normal content - StreamResult::Content(token_text.to_string()) - } - ParserState::MaybeStart => { - self.buffer.push_str(token_text); - - if let Some(tag_pos) = self.buffer.find(&self.config.start_token_str) { - let before = self.buffer[..tag_pos].to_string(); - let after = - self.buffer[tag_pos + self.config.start_token_str.len()..].to_string(); - self.buffer.clear(); - if !after.is_empty() { - self.buffer.push_str(&after); - } - self.state = ParserState::Buffering; - return if before.is_empty() { - StreamResult::Buffering - } else { - StreamResult::Content(before) - }; - } - - if self.partial_suffix_len(&self.buffer) > 0 { - return StreamResult::Buffering; - } - - // False alarm - not a tool call tag - self.state = ParserState::Normal; - let flushed = self.buffer.clone(); + StreamResult::FlushBuffer(self.buffer.clone()) + } else { + StreamResult::ToolCalls(std::mem::take(&mut tool_calls)) + }; self.buffer.clear(); - StreamResult::FlushBuffer(flushed) - } - ParserState::Buffering => { - self.buffer.push_str(token_text); - - let end_reached = self.is_end_token(token_id, token_text) - || self.buffer_has_end_tag() - || self.maybe_complete_mistral_list(); - if end_reached { - crate::log_info!( - "Tool call buffering end, reached {} ({})", - token_text, - token_id - ); - - let tool_calls = self.parse_buffer(); - let result = if tool_calls.is_empty() { - // Parse failed - return buffered content - crate::log_error!( - "Unable to parse tool call buffer: {}\n of accumulated buffer: {}", - self.buffer, - self.accumulated_output - ); - StreamResult::FlushBuffer(self.buffer.clone()) - } else { - StreamResult::ToolCalls(tool_calls) - }; - self.buffer.clear(); - self.state = ParserState::Normal; - return result; - } - - StreamResult::Buffering + self.state = ParserState::Normal; + return result; } + return StreamResult::Buffering; } + + // Not a tool call - flush 
buffered content + self.state = ParserState::Normal; + let flushed = std::mem::take(&mut self.buffer); + StreamResult::Content(flushed) } /// Finalize parsing when stream ends @@ -423,11 +383,12 @@ impl StreamToolParser { self.state = ParserState::Normal; return None; } - let tool_calls = self.parse_buffer(); + let mut tool_calls = + parse_tool_calls_from_text(&self.buffer, &mut self.tool_call_index); if !tool_calls.is_empty() { self.buffer.clear(); self.state = ParserState::Normal; - return Some(tool_calls); + return Some(std::mem::take(&mut tool_calls)); } // Leave buffer intact so caller can flush it. self.state = ParserState::Normal; @@ -446,273 +407,7 @@ impl StreamToolParser { std::mem::take(&mut self.buffer) } - /// Check if token/text matches start trigger - fn is_start_token(&self, id: u32, text: &str) -> bool { - // Token ID match (if available) - if self.config.has_start_tokens() { - return self.config.start_token_ids.contains(&id); - } - // Text match - text.contains(&self.config.start_token_str) - } - - /// Check if token/text matches end trigger - fn is_end_token(&self, id: u32, text: &str) -> bool { - // Token ID match (if available) - if self.config.has_end_tokens() { - return self.config.end_token_ids.contains(&id); - } - if self.parse_strategy == "mistral_list" && self.config.end_token_str == "]" { - return false; - } - // Text match - text.contains(&self.config.end_token_str) - } - - /// Parse buffered content into tool calls - fn parse_buffer(&mut self) -> Vec { - let mut clean_text = self.buffer.trim().to_string(); - if self.should_strip_end_tag() { - if let Some(pos) = clean_text.rfind(&self.config.end_token_str) { - clean_text.truncate(pos); - } - } - let mut calls = Vec::new(); - - // Strategy 1: Mistral List [ {...}, {...} ] - if self.parse_strategy == "mistral_list" && clean_text.starts_with('[') { - if let Ok(list) = serde_json::from_str::>(&clean_text) { - for item in list.iter() { - if let Some(call) = self.json_to_tool_call(item) { - 
calls.push(call); - } - } - } - } - // Strategy 2: Single JSON Object (Qwen, Llama, Phi) - else if let Ok(item) = serde_json::from_str::(&clean_text) { - if let Some(call) = self.json_to_tool_call(&item) { - calls.push(call); - } - } - // Strategy 3: QwenCoder XML-style function tags - else if clean_text.starts_with("") { - // Extract function name from tag - let func_start = "').unwrap_or(0); - if func_end > func_start { - let func_name = &clean_text[func_start..func_end]; - - // Find parameter section - let params_start = clean_text.find(""); - - let mut params = std::collections::HashMap::new(); - - if let (Some(start), Some(end)) = (params_start, params_end) { - let param_content = &clean_text[start..end]; - - // Parse all parameter tags - let mut pos = 0; - while pos < param_content.len() { - let param_start_tag = param_content[pos..].find("").unwrap_or(0) + key_start; - - if key_end > key_start { - let key = ¶m_content[key_start..key_end]; - - let value_start = key_end + 1; - let value_end = param_content[value_start..] 
- .find("") - .unwrap_or(0) - + value_start; - - if value_end > value_start { - let value = ¶m_content[value_start..value_end]; - params.insert(key.to_string(), value.trim().to_string()); - } - } - - // Move position past this parameter - let next_pos = param_content[pos..].find("").unwrap_or(0) - + pos - + "".len(); - if next_pos <= pos { - break; - } - pos = next_pos; - } - } - - if let Ok(args) = serde_json::to_string(¶ms) { - let call = ToolCall { - index: Some(self.tool_call_index), - id: format!("call_{}", uuid::Uuid::new_v4().simple()), - call_type: "function".to_string(), - function: FunctionCall { - name: func_name.to_string(), - arguments: args, - }, - }; - - self.tool_call_index += 1; - calls.push(call); - } - } - } - // Strategy 4: Repair unbalanced JSON and retry - else if let Some(repaired) = self.repair_unbalanced_json(&clean_text) { - if repaired != clean_text { - crate::log_warn!("Tool call JSON missing closing braces; attempting repair"); - } - if let Ok(item) = serde_json::from_str::(&repaired) { - if let Some(call) = self.json_to_tool_call(&item) { - calls.push(call); - } - } - } - - calls - } - fn split_partial_start(&self, text: &str) -> Option<(String, String)> { - let tag = &self.config.start_token_str; - let suffix_len = self.partial_suffix_len(text); - if suffix_len > 0 && suffix_len < tag.len() { - let prefix = text[..text.len() - suffix_len].to_string(); - let partial = text[text.len() - suffix_len..].to_string(); - return Some((prefix, partial)); - } - None - } - - fn partial_suffix_len(&self, text: &str) -> usize { - let tag = &self.config.start_token_str; - let max = std::cmp::min(tag.len(), text.len()); - for i in (1..=max).rev() { - if text.ends_with(&tag[..i]) { - return i; - } - } - 0 - } - - fn buffer_has_end_tag(&self) -> bool { - if self.config.end_token_str.is_empty() { - return false; - } - if self.config.has_end_tokens() { - return false; - } - if self.parse_strategy == "mistral_list" && self.config.end_token_str == "]" { - 
return false; - } - self.buffer.contains(&self.config.end_token_str) - } - - fn maybe_complete_mistral_list(&self) -> bool { - if self.parse_strategy != "mistral_list" { - return false; - } - let trimmed = self.buffer.trim(); - if !trimmed.ends_with(']') { - return false; - } - serde_json::from_str::>(trimmed).is_ok() - } - - fn should_strip_end_tag(&self) -> bool { - let end_tag = self.config.end_token_str.as_str(); - if end_tag.is_empty() { - return false; - } - if self.parse_strategy == "mistral_list" && end_tag == "]" { - return false; - } - end_tag.starts_with('<') - } - - fn repair_unbalanced_json(&self, text: &str) -> Option { - let trimmed = text.trim(); - if !(trimmed.starts_with('{') || trimmed.starts_with('[')) { - return None; - } - - let mut in_string = false; - let mut escape = false; - let mut open_braces = 0usize; - let mut close_braces = 0usize; - let mut open_brackets = 0usize; - let mut close_brackets = 0usize; - - for ch in trimmed.chars() { - if escape { - escape = false; - continue; - } - match ch { - '\\' if in_string => { - escape = true; - } - '"' => { - in_string = !in_string; - } - '{' if !in_string => open_braces += 1, - '}' if !in_string => close_braces += 1, - '[' if !in_string => open_brackets += 1, - ']' if !in_string => close_brackets += 1, - _ => {} - } - } - - if in_string { - return None; - } - if close_braces > open_braces || close_brackets > open_brackets { - return None; - } - - if open_braces == close_braces && open_brackets == close_brackets { - return None; - } - - let mut fixed = trimmed.to_string(); - if open_brackets > close_brackets { - fixed.push_str(&"]".repeat(open_brackets - close_brackets)); - } - if open_braces > close_braces { - fixed.push_str(&"}".repeat(open_braces - close_braces)); - } - Some(fixed) - } - - /// Convert JSON value to ToolCall - fn json_to_tool_call(&mut self, item: &Value) -> Option { - let name = item["name"].as_str()?.to_string(); - let arguments = if let Some(args) = item.get("arguments") { 
- if args.is_string() { - args.as_str().unwrap_or("{}").to_string() - } else { - args.to_string() - } - } else { - "{}".to_string() - }; - - let call = ToolCall { - index: Some(self.tool_call_index), - id: format!("call_{}", uuid::Uuid::new_v4().simple()), - call_type: "function".to_string(), - function: FunctionCall { name, arguments }, - }; - self.tool_call_index += 1; - Some(call) - } + // legacy parsing helpers removed in favor of mistral-style JSON prefix checks // --- Chunk creation helpers (for use by server.rs) --- @@ -763,6 +458,8 @@ impl StreamToolParser { } } +// legacy partial JSON checks moved to tools::parser::prefix_could_be_tool + #[cfg(test)] mod tests { use super::*; @@ -797,7 +494,7 @@ mod tests { let mut parser = StreamToolParser::new(ModelType::Qwen3, "qwen3".to_string()); // Start tag triggers buffering - match parser.process_token(151657, "") { + match parser.process_token(0, "") { StreamResult::Buffering => {} _ => panic!("Expected Buffering on start tag"), } @@ -809,7 +506,7 @@ mod tests { } // End tag triggers parsing - match parser.process_token(151658, "") { + match parser.process_token(0, "") { StreamResult::ToolCalls(calls) => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "test"); @@ -819,38 +516,39 @@ mod tests { } #[test] - fn test_parser_partial_start_text_mode() { - let mut parser = StreamToolParser::new(ModelType::Phi, "phi".to_string()); - - // Partial start tag splits across tokens - match parser.process_token(0, " {} - _ => panic!("Expected Buffering on partial start"), - } - match parser.process_token(0, "call>") { - StreamResult::Buffering => {} - _ => panic!("Expected Buffering on completed start"), - } - match parser.process_token(0, r#"{"name": "test", "arguments": {}}"#) { - StreamResult::Buffering => {} - _ => panic!("Expected Buffering"), + fn test_parser_tool_call_array() { + let mut parser = StreamToolParser::new(ModelType::Mistral, "mistral".to_string()); + let payload = + 
"[TOOL_CALLS][{\"name\":\"a\",\"arguments\":{}},{\"name\":\"b\",\"arguments\":{}}]"; + match parser.process_token(0, payload) { + StreamResult::ToolCalls(calls) => { + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].function.name, "a"); + assert_eq!(calls[1].function.name, "b"); + } + _ => panic!("Expected ToolCalls"), } - match parser.process_token(0, "") { + } + + #[test] + fn test_parser_function_tag_tool_call() { + let mut parser = StreamToolParser::new(ModelType::Qwen3, "qwen3".to_string()); + let payload = r#"{"bar":1}"#; + match parser.process_token(0, payload) { StreamResult::ToolCalls(calls) => { assert_eq!(calls.len(), 1); - assert_eq!(calls[0].function.name, "test"); + assert_eq!(calls[0].function.name, "my_tool"); } _ => panic!("Expected ToolCalls"), } } #[test] - fn test_parser_token_id_strict_match() { - let mut parser = StreamToolParser::new(ModelType::Qwen3, "qwen3".to_string()); + fn test_parser_non_tool_content_flushes() { + let mut parser = StreamToolParser::new(ModelType::Phi, "phi".to_string()); - // Text match should not trigger when token IDs are available - match parser.process_token(0, "") { - StreamResult::Content(text) => assert_eq!(text, ""), + match parser.process_token(0, "Hello ") { + StreamResult::Content(text) => assert_eq!(text, "Hello "), _ => panic!("Expected Content without token ID match"), } } diff --git a/src/server/server.rs b/src/server/server.rs index b1ac4085..3aae6797 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -7,9 +7,10 @@ use super::{ EncodingFormat, TokenizeInput, TokenizeRequest, TokenizeResponse, }; use super::{ - ChatChoice, ChatChoiceChunk, ChatCompletionChunk, ChatCompletionRequest, - ChatCompletionResponse, ChatMessage, ChatResponseMessage, Delta, EmbeddingData, - EmbeddingOutput, EmbeddingUsage, ErrorMsg, ServerData, Usage, UsageQuery, UsageResponse, + constraint_from_response_format, constraint_from_structured_outputs, ChatChoice, + ChatChoiceChunk, ChatCompletionChunk, 
ChatCompletionRequest, ChatCompletionResponse, + ChatMessage, ChatResponseMessage, Delta, EmbeddingData, EmbeddingOutput, EmbeddingUsage, + ErrorMsg, ServerData, Usage, UsageQuery, UsageResponse, }; use crate::core::engine::{LLMEngine, StreamItem}; use crate::server::parser::{ParserState, StreamResult, StreamToolParser}; @@ -18,7 +19,7 @@ use crate::tools::helpers::{ }; use crate::tools::parser::ToolParser; use crate::tools::{ToolChoice, ToolFormat}; -use crate::utils::config::SamplingParams; +use crate::utils::config::{Constraint, SamplingParams}; use axum::{ extract::{Json, Query, State}, response::{sse::KeepAlive, Sse}, @@ -111,15 +112,48 @@ pub async fn chat_completion( params.presence_penalty = request.presence_penalty; params.session_id = request.session_id.clone(); params.thinking = request.thinking.clone(); - let (img_cfg, model_type, tool_config) = { + if let Some(stop_sequences) = &request.stop { + if !stop_sequences.is_empty() { + params.stop_sequences = Some(stop_sequences.clone()); + } + } + let (img_cfg, model_type, tool_config, template_supports_tools) = { let e = data.engine.read(); ( e.img_cfg.clone(), e.model_type.clone(), e.tool_config.clone(), + e.template_supports_tools(), ) }; + let mut structured_constraint: Option = None; + if let Some(response_format) = request.response_format.as_ref() { + match constraint_from_response_format(response_format) { + Ok(constraint) => structured_constraint = constraint, + Err(err) => return ChatResponder::ValidationError(format!("{:?}", err)), + } + } + if let Some(extra_body) = request.extra_body.as_ref() { + if let Some(structured) = extra_body.structured_outputs.as_ref() { + let constraint = match constraint_from_structured_outputs(structured) { + Ok(constraint) => constraint, + Err(err) => return ChatResponder::ValidationError(format!("{:?}", err)), + }; + if constraint.is_some() { + if structured_constraint.is_some() { + return ChatResponder::ValidationError( + "Cannot combine response_format with 
structured_outputs".to_string(), + ); + } + structured_constraint = constraint; + } + } + } + if let Some(constraint) = structured_constraint { + params.constraint = Some(constraint); + } + let mcp_tools = data .mcp_manager .as_ref() @@ -176,44 +210,54 @@ pub async fn chat_completion( let mut chat_messages = request.messages.clone(); if has_tools { - let tool_prompt_template = data.engine.read().econfig.tool_prompt_template.clone(); - let mut tool_prompt = if let Some(template) = tool_prompt_template { - template + let mut tool_prompt: Option = None; + if !template_supports_tools { + let tool_prompt_template = data.engine.read().econfig.tool_prompt_template.clone(); + let mut prompt = if let Some(template) = tool_prompt_template { + template + } else { + ToolFormat::get_tool_prompt(&model_type) + }; + if let Some(instruction) = tool_choice_instruction.as_ref() { + prompt = format!("{prompt}\n\n{instruction}"); + } + tool_prompt = Some(prompt); + } + + let instruction_only = tool_prompt.is_none() && tool_choice_instruction.is_some(); + let system_injection = if instruction_only { + tool_choice_instruction.clone() } else { - ToolFormat::get_tool_prompt(&model_type) + tool_prompt }; - if let Some(instruction) = tool_choice_instruction.as_ref() { - tool_prompt = format!("{tool_prompt}\n\n{instruction}"); - } - // Merge with existing system prompt if present, otherwise insert new one - if !chat_messages.is_empty() && chat_messages[0].role == "system" { - // Merge: tool prompt + newline + existing system content - if let Some(ref content) = chat_messages[0].content { - let existing_content = match content { - super::MessageContentType::PureText(text) => text.clone(), - super::MessageContentType::Single(item) => match item { - super::MessageContent::Text { text } => text.clone(), - _ => String::new(), - }, - super::MessageContentType::Multi(items) => items - .iter() - .filter_map(|item| match item { - super::MessageContent::Text { text } => Some(text.clone()), - _ => 
None, - }) - .collect::>() - .join(" "), - }; - let merged = format!("{}\n\n{}", existing_content, tool_prompt); - chat_messages[0] = ChatMessage::text("system", merged); + if let Some(system_text) = system_injection { + // Merge with existing system prompt if present, otherwise insert new one + if !chat_messages.is_empty() && chat_messages[0].role == "system" { + if let Some(ref content) = chat_messages[0].content { + let existing_content = match content { + super::MessageContentType::PureText(text) => text.clone(), + super::MessageContentType::Single(item) => match item { + super::MessageContent::Text { text } => text.clone(), + _ => String::new(), + }, + super::MessageContentType::Multi(items) => items + .iter() + .filter_map(|item| match item { + super::MessageContent::Text { text } => Some(text.clone()), + _ => None, + }) + .collect::>() + .join(" "), + }; + let merged = format!("{}\n\n{}", existing_content, system_text); + chat_messages[0] = ChatMessage::text("system", merged); + } else { + chat_messages[0] = ChatMessage::text("system", system_text); + } } else { - // System message exists but has no content, just use tool prompt - chat_messages[0] = ChatMessage::text("system", tool_prompt); + chat_messages.insert(0, ChatMessage::text("system", system_text)); } - } else { - // No existing system prompt, insert tool prompt as first message - chat_messages.insert(0, ChatMessage::text("system", tool_prompt)); } } @@ -408,27 +452,53 @@ pub async fn chat_completion( } pending_tool_calls.append(&mut parsed); } else { - // Parse failed - flush any remaining buffer as text + // Parse failed - decide whether to flush or drop incomplete tool call let buffer = tool_parser.take_buffer(); if !buffer.is_empty() { - crate::log_warn!( - "[Seq {}] Tool parse failed, flushing {} chars", - current_seq_id, - buffer.len() - ); - stream_ctx.send_token(&buffer); + let (could_be_tool, complete) = + crate::tools::parser::prefix_could_be_tool(&buffer); + if could_be_tool && !complete 
{ + let snippet: String = + buffer.chars().take(120).collect(); + crate::log_warn!( + "[Seq {}] Incomplete tool call at stream end, dropping {} chars: {}", + current_seq_id, + buffer.len(), + snippet + ); + } else { + crate::log_warn!( + "[Seq {}] Tool parse failed, flushing {} chars", + current_seq_id, + buffer.len() + ); + stream_ctx.send_token(&buffer); + } } } } ParserState::MaybeStart => { let buffer = tool_parser.take_buffer(); if !buffer.is_empty() { - crate::log_warn!( - "[Seq {}] Tool parse partial, flushing {} chars", - current_seq_id, - buffer.len() - ); - stream_ctx.send_token(&buffer); + let (could_be_tool, complete) = + crate::tools::parser::prefix_could_be_tool(&buffer); + if could_be_tool && !complete { + let snippet: String = + buffer.chars().take(120).collect(); + crate::log_warn!( + "[Seq {}] Incomplete tool call prefix at stream end, dropping {} chars: {}", + current_seq_id, + buffer.len(), + snippet + ); + } else { + crate::log_warn!( + "[Seq {}] Tool parse partial, flushing {} chars", + current_seq_id, + buffer.len() + ); + stream_ctx.send_token(&buffer); + } } } ParserState::Normal => {} diff --git a/src/tools/helpers.rs b/src/tools/helpers.rs index 69c7a861..2533d9cf 100644 --- a/src/tools/helpers.rs +++ b/src/tools/helpers.rs @@ -3,7 +3,7 @@ //! //! These functions handle tool resolution, schema mapping, and tool call validation. 
-use super::schema::validate_arguments; +use super::schema::{sanitize_schema_for_llguidance, validate_arguments}; use super::{FunctionCall, Tool, ToolCall}; use serde_json::Value; use std::collections::HashMap; @@ -18,6 +18,16 @@ pub fn resolve_tools(request_tools: Option<&[Tool]>, mcp_tools: &[Tool]) -> Vec< mcp_tools.to_vec() } +pub fn sanitize_tools_for_llguidance(tools: &[Tool]) -> Vec { + tools.iter().map(sanitize_tool_schema).collect() +} + +fn sanitize_tool_schema(tool: &Tool) -> Tool { + let mut tool = tool.clone(); + tool.function.parameters = sanitize_schema_for_llguidance(&tool.function.parameters); + tool +} + /// Build a map of tool names to their parameter schemas pub fn build_tool_schema_map(tools: &[Tool]) -> HashMap { tools diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 474fbde3..2f2041f2 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -244,22 +244,45 @@ impl ToolFormat { let config = ToolConfig::for_model_type(model_type); let start_tag = &config.start_token_str; let end_tag = &config.end_token_str; - let rule = format!( - "MOST IMPORTANT INSTRUCTION, **MUST** FOLLOW: For each function call, you MUST wrap function name and arguments in {start_tag}{end_tag} tags.\n\n\ - Do NOT USE ANY code blocks. 
Required format:\n\ - {start_tag}\n\ - {{\"name\": \"\", \"arguments\": }}\n\ - {end_tag}\n\n\ - Rules:\n\ - - Wrap function name and arguments with {start_tag} and {end_tag} tags\n\ - - Always use the exact {start_tag}{end_tag} format shown above\n\ - - Do NOT USE ANY code blocks\n\ - - Tool-use must be placed **at the end** of your response (**AFTER REASONING**), **top-level**, and not nested within other tags.\n\ - - Always adhere to this format for the tool use to ensure proper parsing and execution.\n\ - - The \"name\" and \"arguments\" are necessary fields\n\ - - DO NOT call ANY functions that DOES NOT defined between and \n\ - - MUST FOLLOW the above instruction when using tool call!", - ); - rule + match model_type { + crate::utils::config::ModelType::Qwen3 + | crate::utils::config::ModelType::Qwen3MoE + | crate::utils::config::ModelType::Qwen3VL => { + format!( + "MOST IMPORTANT INSTRUCTION, **MUST** FOLLOW: For each function call, you MUST use the QwenCoder tool format.\n\n\ + Required format:\n\ + {start_tag}\n\ + >\n\ + >\n\ + ...\n\ + \n\ + {end_tag}\n\n\ + Rules:\n\ + - Wrap tool calls with {start_tag} and {end_tag}\n\ + - Use and tags\n\ + - Each value MUST be valid JSON (string/object/array/number/bool)\n\ + - Do NOT USE ANY code blocks\n\ + - Tool-use must be placed at the end of your response (after reasoning)\n\ + - Only call tools defined between and \n\ + - MUST FOLLOW the above instruction when using tool call!", + ) + } + _ => format!( + "MOST IMPORTANT INSTRUCTION, **MUST** FOLLOW: For each function call, you MUST wrap function name and arguments in {start_tag}{end_tag} tags.\n\n\ + Do NOT USE ANY code blocks. 
Required format:\n\ + {start_tag}\n\ + {{\"name\": \"\", \"arguments\": }}\n\ + {end_tag}\n\n\ + Rules:\n\ + - Wrap function name and arguments with {start_tag} and {end_tag} tags\n\ + - Always use the exact {start_tag}{end_tag} format shown above\n\ + - Do NOT USE ANY code blocks\n\ + - Tool-use must be placed **at the end** of your response (**AFTER REASONING**), **top-level**, and not nested within other tags.\n\ + - Always adhere to this format for the tool use to ensure proper parsing and execution.\n\ + - The \"name\" and \"arguments\" are necessary fields\n\ + - DO NOT call ANY functions that DOES NOT defined between and \n\ + - MUST FOLLOW the above instruction when using tool call!", + ), + } } } diff --git a/src/tools/parser.rs b/src/tools/parser.rs index 7363dc0e..50658346 100644 --- a/src/tools/parser.rs +++ b/src/tools/parser.rs @@ -5,13 +5,17 @@ use super::ToolCall; use regex::Regex; -use serde_json::Value; +use serde::de::{self, Deserializer, MapAccess, Visitor}; +use serde_json::{Map, Value}; +use std::fmt; +use std::sync::OnceLock; /// Parser for extracting tool calls from model output text #[allow(dead_code)] #[derive(Debug, Clone)] pub struct ToolParser { /// Regex patterns for different formats + #[allow(dead_code)] patterns: Vec<(String, Regex)>, } @@ -47,12 +51,18 @@ impl ToolParser { /// Parse tool calls from model output /// Only parses tool calls from the final answer (after reasoning end markers) pub fn parse(&self, text: &str) -> Vec { - let mut calls = Vec::new(); let mut call_id = 0; // Extract only the final answer portion (after reasoning ends) let final_answer = Self::extract_final_answer(text); + // Mistral-style parsing: strip wrappers and parse JSON or JSON array. 
+ let mut calls = parse_tool_calls_from_text(&final_answer, &mut call_id); + + if !calls.is_empty() { + return calls; + } + // Try Qwen format first if let Some(qwen_calls) = self.parse_qwen_format(&final_answer, &mut call_id) { calls.extend(qwen_calls); @@ -128,9 +138,7 @@ impl ToolParser { } if let Ok(parsed) = serde_json::from_str::(trimmed) { - if let Some(call) = self.value_to_tool_call(&parsed, call_id) { - calls.push(call); - } + calls.extend(self.value_to_tool_calls(&parsed, call_id)); } } } @@ -150,8 +158,9 @@ impl ToolParser { // Simple approach: try to parse the entire text as JSON first if let Ok(parsed) = serde_json::from_str::(text.trim()) { - if let Some(call) = self.value_to_tool_call(&parsed, call_id) { - return Some(vec![call]); + let parsed_calls = self.value_to_tool_calls(&parsed, call_id); + if !parsed_calls.is_empty() { + return Some(parsed_calls); } } @@ -173,9 +182,7 @@ impl ToolParser { if let Some(s) = start { let json_str = &text[s..=i]; if let Ok(parsed) = serde_json::from_str::(json_str) { - if let Some(call) = self.value_to_tool_call(&parsed, call_id) { - calls.push(call); - } + calls.extend(self.value_to_tool_calls(&parsed, call_id)); } } start = None; @@ -200,9 +207,7 @@ impl ToolParser { for cap in re.captures_iter(text) { if let Some(content) = cap.get(1) { if let Ok(parsed) = serde_json::from_str::(content.as_str().trim()) { - if let Some(call) = self.value_to_tool_call(&parsed, call_id) { - calls.push(call); - } + calls.extend(self.value_to_tool_calls(&parsed, call_id)); } } } @@ -214,31 +219,32 @@ impl ToolParser { } } - /// Convert a JSON Value to a ToolCall if it has the right structure - fn value_to_tool_call(&self, value: &Value, call_id: &mut usize) -> Option { - let name = value.get("name")?.as_str()?; - let arguments = value.get("arguments")?; - - let args_str = if arguments.is_string() { - arguments.as_str().unwrap().to_string() - } else { - serde_json::to_string(arguments).ok()? 
- }; - - *call_id += 1; - Some(ToolCall::new( - format!("call_{}", call_id), - name.to_string(), - args_str, - )) + /// Convert a JSON Value to ToolCall(s) if it has the right structure + fn value_to_tool_calls(&self, value: &Value, call_id: &mut usize) -> Vec { + match value { + Value::Array(items) => items + .iter() + .flat_map(|item| self.value_to_tool_calls(item, call_id)) + .collect(), + Value::Object(_) => { + if let Some(call) = json_value_to_tool_call(value, call_id) { + vec![call] + } else { + Vec::new() + } + } + _ => Vec::new(), + } } - /// Check if text contains any tool calls (only explicit XML tags in final answer) - /// Note: Raw JSON patterns are NOT checked to avoid false positives in reasoning + /// Check if text contains any tool calls (JSON or model-specific wrappers) pub fn has_tool_calls(&self, text: &str) -> bool { let final_answer = Self::extract_final_answer(text); - // Only check for explicit XML-wrapped tool calls - final_answer.contains("") + let mut call_id = 0; + if !parse_tool_calls_from_text(&final_answer, &mut call_id).is_empty() { + return true; + } + contains_tool_call_prefix(&final_answer) } /// Check if text contains a complete, parseable tool call @@ -269,6 +275,316 @@ impl ToolParser { } } +// --- Mistral-style tool parsing helpers --- + +// Accept either `{...}` **or** a `"stringified { ... }"` +fn flexible_args<'de, D>(d: D) -> std::result::Result +where + D: Deserializer<'de>, +{ + struct ArgVisitor; + + impl<'de> Visitor<'de> for ArgVisitor { + type Value = Value; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("an object or a JSON-encoded string containing an object") + } + + fn visit_map(self, mut m: M) -> std::result::Result + where + M: MapAccess<'de>, + { + let mut map = Map::new(); + while let Some((k, v)) = m.next_entry()? 
{ + map.insert(k, v); + } + Ok(Value::Object(map)) + } + + fn visit_str(self, s: &str) -> std::result::Result + where + E: de::Error, + { + serde_json::from_str(s).map_err(|e| E::custom(format!("inner JSON error: {e}"))) + } + } + + d.deserialize_any(ArgVisitor) +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +struct CalledFunctionParameters { + #[serde(alias = "function")] + name: String, + #[serde(alias = "arguments", deserialize_with = "flexible_args")] + parameters: Value, +} + +fn contains_tool_call_prefix(prefix: &str) -> bool { + prefix.contains("") + || prefix.contains("<|tool▁call▁begin|>") + || prefix.contains("<|python_tag|>") + || prefix.contains("[TOOL_CALLS]") +} + +fn process_model_specific_message(message: &str) -> String { + static DEEPSEEK_REGEX: OnceLock = OnceLock::new(); + static QWEN_REGEX: OnceLock = OnceLock::new(); + + let deepseek_regex = DEEPSEEK_REGEX.get_or_init(|| { + Regex::new( + r"(?s)<|tool▁call▁begin|>function<|tool▁sep|>(?P[^\n]+)\n```json\n(?P.+?)\n```<|tool▁call▁end|>", + ) + .unwrap() + }); + let qwen_regex = QWEN_REGEX + .get_or_init(|| Regex::new(r"(?s)(?P.*?)").unwrap()); + + if let Some(message) = message.strip_prefix("<|python_tag|>") { + message + .strip_suffix("<|eom_id|>") + .unwrap_or(message) + .to_string() + } else if qwen_regex.is_match(message) { + if let Some(caps) = qwen_regex.captures(message) { + let inner = caps.name("inner").unwrap().as_str(); + return inner.trim().to_string(); + } + message.to_string() + } else if let Some(message) = message + .strip_prefix("[TOOL_CALLS][") + .and_then(|s| s.strip_suffix("]")) + { + message.to_string() + } else if deepseek_regex.find(message).is_some() { + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] + struct ToolCall { + name: String, + arguments: Value, + } + let mut calls = Vec::new(); + for caps in deepseek_regex.captures_iter(message) { + let name = caps + .name("name") + .map(|m| m.as_str().trim().to_string()) + 
.unwrap_or_default(); + let json_str = caps.name("json").map(|m| m.as_str().trim()).unwrap_or("{}"); + let arguments: Value = + serde_json::from_str(json_str).unwrap_or_else(|_| Value::Object(Map::new())); + calls.push(ToolCall { name, arguments }); + } + serde_json::to_string(&calls).unwrap_or_else(|_| message.to_string()) + } else { + message.to_string() + } +} + +fn fix_broken_json(raw: &str) -> String { + if raw.contains(r#""arguments":"{"#) { + let tmp = raw.replacen(r#""arguments":"{"#, r#""arguments":{"#, 1); + tmp.replacen(r#"}"}"#, r#"}}"#, 1) + } else { + raw.to_string() + } +} + +fn json_value_to_tool_call(value: &Value, call_id: &mut usize) -> Option { + let name = value.get("name")?.as_str()?.to_string(); + let arguments = value.get("arguments")?; + let args_str = if arguments.is_string() { + arguments.as_str().unwrap_or("{}").to_string() + } else { + serde_json::to_string(arguments).ok()? + }; + + let call = ToolCall { + index: Some(*call_id), + id: format!("call_{}", uuid::Uuid::new_v4().simple()), + call_type: "function".to_string(), + function: super::FunctionCall { + name, + arguments: args_str, + }, + }; + *call_id += 1; + Some(call) +} + +/// Parse tool calls from a raw message string (handles model-specific wrappers). 
+pub fn parse_tool_calls_from_text(text: &str, call_id: &mut usize) -> Vec { + // First, handle explicit wrappers (may appear multiple times) + if text.contains("") { + let mut calls = Vec::new(); + if let Ok(re) = Regex::new(r"(?s)\s*(.*?)\s*") { + for cap in re.captures_iter(text) { + if let Some(inner) = cap.get(1) { + let inner = inner.as_str().trim(); + if let Ok(parsed) = serde_json::from_str::(inner) { + if let Some(call) = json_value_to_tool_call(&parsed, call_id) { + calls.push(call); + } + continue; + } + + if let Some(call) = parse_function_tag_tool_call(inner, call_id) { + calls.push(call); + } + } + } + } + if !calls.is_empty() { + return calls; + } + } + + let processed = process_model_specific_message(text); + let processed = fix_broken_json(&processed); + + if let Ok(deser) = serde_json::from_str::(&processed) { + let args = serde_json::to_string(&deser.parameters).unwrap_or_else(|_| "{}".to_string()); + let call = ToolCall { + index: Some(*call_id), + id: format!("call_{}", uuid::Uuid::new_v4().simple()), + call_type: "function".to_string(), + function: super::FunctionCall { + name: deser.name, + arguments: args, + }, + }; + *call_id += 1; + return vec![call]; + } + + if let Ok(deser) = serde_json::from_str::>(&processed) { + let mut out = Vec::new(); + for item in deser { + let args = serde_json::to_string(&item.parameters).unwrap_or_else(|_| "{}".to_string()); + out.push(ToolCall { + index: Some(*call_id), + id: format!("call_{}", uuid::Uuid::new_v4().simple()), + call_type: "function".to_string(), + function: super::FunctionCall { + name: item.name, + arguments: args, + }, + }); + *call_id += 1; + } + return out; + } + + Vec::new() +} + +/// Checks if the given prefix could be the start of, or the entire JSON serialization of a tool call. +/// Returns (could_be_tool, is_complete_tool). 
+pub fn prefix_could_be_tool(prefix: &str) -> (bool, bool) { + if prefix.trim().is_empty() { + return (false, false); + } + + // If we already have a full ..., attempt to parse directly. + if prefix.contains("") { + let mut call_id = 0; + if !parse_tool_calls_from_text(prefix, &mut call_id).is_empty() { + return (false, true); + } + } + + // If we see a start tag, it's at least a potential tool call. + if prefix.contains("") { + return (true, false); + } + + let processed = process_model_specific_message(prefix); + let processed = fix_broken_json(&processed); + + let checks = [ + could_be_json::, + could_be_json::>, + ]; + + for check in checks { + let (could_be, complete) = check(&processed); + if could_be || complete { + return (could_be, complete); + } + } + + ( + contains_tool_call_prefix(prefix) || contains_tool_call_prefix(&processed), + false, + ) +} + +fn could_be_json(text_prefix: &str) -> (bool, bool) +where + T: serde::de::DeserializeOwned, +{ + if text_prefix.trim().is_empty() { + return (false, false); + } + match serde_json::from_str::(text_prefix) { + Ok(_) => (false, true), + Err(e) if e.is_eof() => (true, false), + _ => (false, false), + } +} + +fn parse_function_tag_tool_call(inner: &str, call_id: &mut usize) -> Option { + let func_tag = "')? 
+ name_start; + if name_end <= name_start { + return None; + } + let func_name = inner[name_start..name_end].trim(); + if func_name.is_empty() { + return None; + } + + let mut params = Map::new(); + let mut pos = name_end + 1; + while let Some(param_tag_pos) = inner[pos..].find("") + .map(|v| v + value_start)?; + if value_end <= value_start { + break; + } + let value_raw = inner[value_start..value_end].trim(); + let value = serde_json::from_str::(value_raw) + .unwrap_or_else(|_| Value::String(value_raw.to_string())); + params.insert(key.to_string(), value); + pos = value_end + "".len(); + } + + let args = Value::Object(params); + let args_str = serde_json::to_string(&args).ok()?; + + let call = ToolCall { + index: Some(*call_id), + id: format!("call_{}", uuid::Uuid::new_v4().simple()), + call_type: "function".to_string(), + function: super::FunctionCall { + name: func_name.to_string(), + arguments: args_str, + }, + }; + *call_id += 1; + Some(call) +} + #[cfg(test)] mod tests { use super::*; @@ -333,4 +649,25 @@ mod tests { assert!(parser.has_tool_calls(r#"{"name": "foo", "arguments": {}}"#)); assert!(!parser.has_tool_calls("Just a normal response")); } + + #[test] + fn test_parse_function_tag_format() { + let parser = ToolParser::new(); + let text = r#" + + +{"bar": 1} + + +qux + + +"#; + + let calls = parser.parse(text); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "my_tool"); + assert!(calls[0].function.arguments.contains("\"foo\"")); + assert!(calls[0].function.arguments.contains("\"baz\"")); + } } diff --git a/src/tools/schema.rs b/src/tools/schema.rs index 8a4f25c6..81e8838b 100644 --- a/src/tools/schema.rs +++ b/src/tools/schema.rs @@ -3,8 +3,29 @@ //! //! Provides helpers for working with JSON Schema in tool definitions. -use serde_json::{json, Value}; +use crate::tools::Tool; +use serde_json::{json, Map, Value}; use std::collections::HashMap; +/// Remove JSON Schema features that llguidance doesn't support. 
+/// Currently strips all "format" fields recursively. +pub fn sanitize_schema_for_llguidance(schema: &Value) -> Value { + match schema { + Value::Object(map) => { + let mut out = Map::new(); + for (key, value) in map { + if key == "format" { + continue; + } + out.insert(key.clone(), sanitize_schema_for_llguidance(value)); + } + Value::Object(out) + } + Value::Array(items) => { + Value::Array(items.iter().map(sanitize_schema_for_llguidance).collect()) + } + _ => schema.clone(), + } +} /// Builder for creating JSON Schema objects #[derive(Debug, Clone, Default)] @@ -240,6 +261,203 @@ pub fn validate_arguments(schema: &Value, arguments: &Value) -> Result<(), Strin Ok(()) } +/// Build a JSON Schema for tool calls. +/// Supports a single tool call object or an array of tool call objects. +pub fn build_tool_call_schema(tools: &[Tool]) -> Value { + let mut variants = Vec::new(); + + for tool in tools { + let name = tool.function.name.clone(); + let mut args_schema = tool.function.parameters.clone(); + + // If strict mode is requested and schema is object-like, disallow extra properties. 
+ if tool.function.strict.unwrap_or(false) { + if args_schema.get("type") == Some(&Value::String("object".to_string())) + && args_schema.get("additionalProperties").is_none() + { + args_schema["additionalProperties"] = Value::Bool(false); + } + } + + let variant = json!({ + "type": "object", + "properties": { + "name": { "const": name }, + "arguments": args_schema + }, + "required": ["name", "arguments"], + "additionalProperties": false + }); + variants.push(variant); + } + + let tool_call_schema = if variants.len() == 1 { + variants.into_iter().next().unwrap_or_else(|| json!({})) + } else { + json!({ "oneOf": variants }) + }; + + json!({ + "oneOf": [ + tool_call_schema, + { + "type": "array", + "items": tool_call_schema, + "minItems": 1 + } + ] + }) +} + +fn lark_quote(value: &str) -> String { + let escaped = value.replace('\\', "\\\\").replace('"', "\\\""); + format!("\"{}\"", escaped) +} + +fn lark_literal(value: &str, is_special: bool) -> String { + if is_special && value.starts_with('<') && value.ends_with('>') { + value.to_string() + } else { + lark_quote(value) + } +} + +/// Build a Lark grammar that wraps a tool call JSON schema between start/end markers. +pub fn build_tool_call_lark_grammar( + schema: &Value, + start: &str, + end: &str, + start_is_special: bool, + end_is_special: bool, +) -> String { + let schema_json = serde_json::to_string(schema).unwrap_or_else(|_| "{}".to_string()); + + if start.is_empty() || end.is_empty() { + return format!("start: tool\ntool: %json {schema_json}\n"); + } + + let start_lit = lark_literal(start, start_is_special); + let end_lit = lark_literal(end, end_is_special); + + format!( + "start: {start_lit} _WS? tool _WS? {end_lit}\n\ + tool: %json {schema_json}\n\ + _WS: /[ \\t\\r\\n]+/\n" + ) +} + +/// Build a Lark grammar for QwenCoder-style function/parameter tags with JSON values. 
+pub fn build_function_tag_lark_grammar( + tools: &[Tool], + start: &str, + end: &str, + start_is_special: bool, + end_is_special: bool, +) -> String { + let mut rules: Vec = Vec::new(); + let start_tag = if start.is_empty() { + None + } else { + Some(lark_literal(start, start_is_special)) + }; + let end_tag = if end.is_empty() { + None + } else { + Some(lark_literal(end, end_is_special)) + }; + + let tool_rule_names: Vec = (0..tools.len()).map(|i| format!("tool_{i}")).collect(); + let toolcall_rule = if tool_rule_names.is_empty() { + "toolcall:".to_string() + } else { + format!("toolcall: {}", tool_rule_names.join(" | ")) + }; + + if let (Some(start_lit), Some(end_lit)) = (start_tag.as_ref(), end_tag.as_ref()) { + rules.push(format!("start: {start_lit} _WS? toolcall _WS? {end_lit}")); + } else { + rules.push("start: toolcall".to_string()); + } + + rules.push(toolcall_rule); + + for (tool_idx, tool) in tools.iter().enumerate() { + let func_start = lark_quote(&format!("", tool.function.name)); + let func_end = lark_quote(""); + + let params_schema = &tool.function.parameters; + let props = params_schema.get("properties").and_then(|p| p.as_object()); + let defs = params_schema.get("$defs").cloned(); + let definitions = params_schema.get("definitions").cloned(); + + let mut param_rule_names = Vec::new(); + + if let Some(props) = props { + for (param_idx, (param_name, schema)) in props.iter().enumerate() { + let param_tag = lark_quote(&format!("", param_name)); + let param_end = lark_quote(""); + let value_rule = format!("value_{tool_idx}_{param_idx}"); + let param_rule = format!("param_{tool_idx}_{param_idx}"); + let schema_with_defs = if defs.is_some() || definitions.is_some() { + if let Some(obj) = schema.as_object() { + let mut merged = obj.clone(); + // Preserve existing $defs/definitions in param schema if present. 
+ if let Some(ref defs_val) = defs { + merged + .entry("$defs".to_string()) + .or_insert_with(|| defs_val.clone()); + } + if let Some(ref defs_val) = definitions { + merged + .entry("definitions".to_string()) + .or_insert_with(|| defs_val.clone()); + } + serde_json::Value::Object(merged) + } else { + let mut merged = serde_json::Map::new(); + merged.insert("allOf".to_string(), json!([schema.clone()])); + if let Some(ref defs_val) = defs { + merged.insert("$defs".to_string(), defs_val.clone()); + } + if let Some(ref defs_val) = definitions { + merged.insert("definitions".to_string(), defs_val.clone()); + } + serde_json::Value::Object(merged) + } + } else { + schema.clone() + }; + + let schema_json = + serde_json::to_string(&schema_with_defs).unwrap_or_else(|_| "{}".to_string()); + + rules.push(format!("{value_rule}: %json {schema_json}")); + rules.push(format!( + "{param_rule}: {param_tag} {value_rule} {param_end}" + )); + param_rule_names.push(param_rule); + } + } + + let params_expr = if param_rule_names.is_empty() { + String::new() + } else { + format!("({})*", param_rule_names.join(" | ")) + }; + + if params_expr.is_empty() { + rules.push(format!("tool_{tool_idx}: {func_start} _WS? {func_end}")); + } else { + rules.push(format!( + "tool_{tool_idx}: {func_start} _WS? {params_expr} _WS? 
{func_end}" + )); + } + } + + rules.push("_WS: /[ \\t\\r\\n]+/".to_string()); + rules.join("\n") +} + fn validate_type(schema: &Value, value: &Value, field_name: &str) -> Result<(), String> { if let Some(enum_values) = schema.get("enum").and_then(|v| v.as_array()) { if !enum_values.iter().any(|v| v == value) { @@ -473,4 +691,20 @@ mod tests { assert!(validate_arguments(&schema, &valid).is_ok()); assert!(validate_arguments(&schema, &invalid).is_err()); } + + #[test] + fn test_sanitize_schema_strips_format() { + let schema = json!({ + "type": "object", + "properties": { + "url": {"type": "string", "format": "uri"}, + "nested": {"type": "object", "properties": {"id": {"type": "string", "format": "uuid"}}} + } + }); + let sanitized = sanitize_schema_for_llguidance(&schema); + assert!(sanitized["properties"]["url"].get("format").is_none()); + assert!(sanitized["properties"]["nested"]["properties"]["id"] + .get("format") + .is_none()); + } } diff --git a/src/utils/chat_template.rs b/src/utils/chat_template.rs index 8f5e06a3..e4833849 100644 --- a/src/utils/chat_template.rs +++ b/src/utils/chat_template.rs @@ -103,6 +103,19 @@ impl ChatTemplate { self.enable_thinking = enable; } + pub fn supports_tools(&self) -> bool { + let Some(template) = &self.chat_template else { + return false; + }; + let lower = template.to_lowercase(); + lower.contains("tools") + || lower.contains("tool_calls") + || lower.contains("[available_tools]") + || lower.contains("") + } + #[allow(dead_code)] fn clear_message(&mut self) { self.messages.clear() diff --git a/src/utils/config.rs b/src/utils/config.rs index 47bcc47d..1f5529d4 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -1,10 +1,20 @@ // src/utils/config.rs use crate::transfer::PdConfig; +use llguidance::api::TopLevelGrammar; #[cfg(feature = "python")] use pyo3::pyclass; use serde::de::value::SeqAccessDeserializer; use serde::de::{Deserializer, Visitor}; use serde::{Deserialize, Serialize, Serializer}; + +#[derive(Clone, 
Debug, Serialize, Deserialize)] +pub enum Constraint { + Regex(String), + Lark(String), + JsonSchema(serde_json::Value), + Llguidance(TopLevelGrammar), + None, +} use std::collections::HashMap; use std::fmt; @@ -431,6 +441,8 @@ pub struct SamplingParams { /// If Some(true), external tools are enabled and stream finishes at . #[serde(default)] pub mcp_mode: Option, + #[serde(default)] + pub constraint: Option, } #[cfg(feature = "python")] @@ -465,6 +477,8 @@ pub struct SamplingParams { #[pyo3(get, set)] #[serde(alias = "enable_thinking")] pub thinking: Option, + #[serde(default)] + pub constraint: Option, } #[cfg(not(feature = "python"))] @@ -492,6 +506,7 @@ impl SamplingParams { mcp_mode: None, stop_sequences: None, stop_token_ids: None, + constraint: None, thinking, } } @@ -509,6 +524,7 @@ impl SamplingParams { mcp_mode: None, stop_sequences: None, stop_token_ids: None, + constraint: None, thinking: None, } } @@ -528,6 +544,7 @@ impl Default for SamplingParams { mcp_mode: None, stop_sequences: None, stop_token_ids: None, + constraint: None, thinking: None, } } diff --git a/src/utils/guidance.rs b/src/utils/guidance.rs index e37dbad8..aa446326 100644 --- a/src/utils/guidance.rs +++ b/src/utils/guidance.rs @@ -1,54 +1,93 @@ // src/utils/guidance.rs -//! Guided decoding support via llguidance. -//! -//! NOTE: This module is currently stubbed out due to API changes in llguidance >= 0.6. -//! The TopLevelGrammar::from_json_schema method is no longer available. -//! Guided decoding features are temporarily disabled. 
- -use serde_json::Value; -use std::path::Path; +use anyhow::Result; +use llguidance::{api::TopLevelGrammar, Matcher, ParserFactory as LlgParserFactory}; use std::sync::Arc; +use tokenizers::Tokenizer; +use toktrie::{SimpleVob, TokTrie}; +use toktrie_hf_tokenizers::{ByteTokenizer, ByteTokenizerEnv}; -// Import toktrie from the crate root (it's re-exported by llguidance) -pub use toktrie::TokTrie; +use crate::utils::config::Constraint; pub struct GuidanceState { - // Placeholder for future implementation - _phantom: std::marker::PhantomData<()>, + matcher: Matcher, } impl GuidanceState { - pub fn new(_toktrie: Arc, _schema: Value) -> anyhow::Result { - // Stubbed out - guided decoding temporarily disabled - anyhow::bail!("Guided decoding is temporarily disabled due to llguidance API changes. \ - The TopLevelGrammar::from_json_schema method is no longer available in llguidance >= 0.6") + pub fn new(factory: Arc, constraint: &Constraint) -> Result { + let grammar = llg_grammar_from_constraint(constraint)?; + let grammar = match grammar { + Some(g) => g, + None => { + // If None, we probably shouldn't be creating a GuidanceState, or we create a dummy one + // But generally the caller guards this. + // For now, let's error if called with None, or we can handle it. + // Actually, let's support it if needed, but for now strict. 
+ anyhow::bail!("Cannot create GuidanceState from Constraint::None"); + } + }; + + let parser = factory.create_parser(grammar)?; + let matcher = Matcher::new(Ok(parser)); + Ok(Self { matcher }) + } + + pub fn compute_mask(&mut self) -> Result> { + if self.matcher.is_stopped() { + return Ok(None); + } + // compute_mask returns a standard bitmask or list of tokens + self.matcher.compute_mask().map(Some).map_err(Into::into) } - pub fn compute_allowed_tokens(&mut self) -> anyhow::Result { - anyhow::bail!("Guided decoding is temporarily disabled") + pub fn commit_token(&mut self, token: u32) -> Result<()> { + if !self.matcher.is_stopped() { + self.matcher.consume_token(token)?; + } + Ok(()) } - pub fn commit_token(&mut self, _token: u32) -> anyhow::Result<()> { - anyhow::bail!("Guided decoding is temporarily disabled") + pub fn is_finished(&self) -> bool { + self.matcher.is_stopped() } } -pub struct AllowedTokens { - pub tokens: Vec, - pub is_stopped: bool, +pub type ParserFactory = LlgParserFactory; + +pub fn build_llg_factory( + tokenizer: Tokenizer, + vocab_size: Option, +) -> Result> { + let tokenizer_vocab = tokenizer.get_vocab_size(true); + let target_vocab = vocab_size.map(|v| { + if v < tokenizer_vocab { + crate::log_warn!( + "Requested vocab size {} is smaller than tokenizer vocab size {}. Using tokenizer size.", + v, + tokenizer_vocab + ); + tokenizer_vocab + } else { + v + } + }); + let env = ByteTokenizer::from_tokenizer(tokenizer)?.into_tok_env(target_vocab)?; + let factory = ParserFactory::new_simple(&env)?; + Ok(Arc::new(factory)) } -pub fn build_toktrie_from_tokenizer_bytes(bytes: &[u8]) -> anyhow::Result { - // Try to build TokTrie from bytes - // The new API uses TokTrie::from() with TokRxInfo and words - // For now, return an error as the exact migration path needs investigation - anyhow::bail!("TokTrie construction from tokenizer bytes is temporarily disabled. \ - The TokTrie::from_huggingface_bytes method is no longer available in toktrie >= 1.0. 
\ - Input bytes length: {}", bytes.len()) +pub fn load_toktrie_from_path(path: impl AsRef) -> Result { + let tokenizer = ByteTokenizer::from_file(path)?; + let env = ByteTokenizerEnv::new(tokenizer, None)?; + Ok(env.tok_trie) } -pub fn load_toktrie_from_path(_: &Path) -> Option { - // Temporarily disabled - returns None - // crate::log_warn!("load_toktrie_from_path is disabled: {:?}", path); - None +pub fn llg_grammar_from_constraint(constraint: &Constraint) -> Result> { + let grm = match constraint { + Constraint::Regex(regex) => TopLevelGrammar::from_regex(regex), + Constraint::Lark(lark) => TopLevelGrammar::from_lark(lark.clone()), + Constraint::JsonSchema(value) => TopLevelGrammar::from_json_schema(value.clone()), + Constraint::Llguidance(value) => value.clone(), + Constraint::None => return Ok(None), + }; + Ok(Some(grm)) }