diff --git a/Cargo.toml b/Cargo.toml index da3162c8..25729e79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vllm-rs" -version = "0.8.10" +version = "0.8.11" edition = "2021" default-run = "vllm-rs" @@ -61,6 +61,8 @@ bytemuck = "1.24.0" regex = "1.12.2" local-ip-address = "0.6.5" url = "2.5.7" +tool-parser = "1.0" +openai-protocol = "1.0" [lib] name = "vllm_rs" diff --git a/ReadMe-CN.md b/ReadMe-CN.md index 3be70d2f..898cf3c6 100644 --- a/ReadMe-CN.md +++ b/ReadMe-CN.md @@ -77,6 +77,7 @@ ## 📚 文档 - [快速开始](docs/get_started.md) - [Docker构建](docs/docker.md) +- [工具调用解析](docs/tool_parsing.md) - [MCP集成与工具调用](docs/mcp_tool_calling.md) - [Claude Code使用vLLM.rs后端](docs/claude_code.md) - [Goose AI Agent使用vLLM.rs后端](docs/goose.md) diff --git a/ReadMe.md b/ReadMe.md index 9eb245d0..98a6cacc 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -78,6 +78,7 @@ All models support hardware FP8 KV-cache acceleration (requires SM90+ and disabl ## 📚 Guides - [Get Started](docs/get_started.md) - [Docker Build](docs/docker.md) +- [Tool Parsing](docs/tool_parsing.md) - [MCP Integration and Tool Calling](docs/mcp_tool_calling.md) - [Work with Claude Code](docs/claude_code.md) - [Work with Goose AI Agent](docs/goose.md) diff --git a/docs/tool_parsing.md b/docs/tool_parsing.md new file mode 100644 index 00000000..cd835d4f --- /dev/null +++ b/docs/tool_parsing.md @@ -0,0 +1,74 @@ +## Tool Call Parsing + +This project uses the model-specific parsers for parsing tool calls for both +streaming and non-streaming responses. The goal is to keep parsing logic +consistent across models while remaining robust to partial output and format +differences. + +### Parser selection + +Parser selection follows this order: + +1. `--enforce-parser` (if provided and valid). +2. Model-based heuristics (model type + model id). +3. Fallback to `passthrough`. + +If you pass an invalid name to `--enforce-parser`, the server returns an error +and includes the list of valid parser names. + +Available parser names: + +- passthrough +- json +- mistral +- qwen +- qwen_coder +- pythonic +- llama +- deepseek +- glm45_moe +- glm47_moe +- step3 +- kimik2 +- minimax_m2 + +### Streaming parsing + +Streaming requests use incremental parsing logic. Each incoming token is appended to an internal buffer and fed into the parser. The parser emits streaming tool call fragments, which are accumulated into full tool calls. + +When an end marker is detected (token id or `` tag), the stream parser: + +1. Flushes any unstreamed arguments from the external parser. +2. Builds tool calls from the accumulated fragments. +3. If no tool calls were produced, falls back to `parse_complete_with_fallback` + on the buffered content. + +If parsing still fails, the buffered content is emitted as normal text so the +client does not lose output. + +### Non-streaming parsing + +Non-streaming requests reuse the same stream parser and call +`parse_complete_with_fallback`. This keeps parser selection and fallback logic +identical between streaming and non-streaming paths. + +### Enforcing a parser + +CLI (Rust server): + +``` +--enforce-parser qwen_coder +``` + +Python server example (`server.py` or `vllm_rs.server`): + +``` +--enforce-parser qwen_coder +``` + +### Environment Variables + +- `VLLM_RS_STRICT_TOOL_CALL`: + - `1` or `true`: Strict validation. Dropping invalid tool calls (calls that do not match the schema) effectively preventing them from being sent to the client. The server logs a warning for dropped calls. + - `0` or `false` (default): Lenient validation. Invalid tool calls are kept and sent to the client, but a warning is logged by the server. This allows models to output "hallucinated" or malformed calls if desired. + diff --git a/example/server.py b/example/server.py index 667db0f5..2a8359ac 100644 --- a/example/server.py +++ b/example/server.py @@ -37,12 +37,39 @@ def parse_args(): parser.add_argument("--mcp_config", type=str, default=None) parser.add_argument("--mcp_command", type=str, default=None) parser.add_argument("--mcp_args", type=str, default=None) + parser.add_argument("--enforce-parser", type=str, default=None) parser.add_argument("--pd-server-prefix-cache-ratio", type=float, default=None) parser.add_argument("--pd-client-prefix-cache-ratio", type=float, default=None) args = parser.parse_args() if args.pd_server and args.ui_server: raise ValueError("PD Server cannot run with UI Server enabled!") + if args.enforce_parser is not None: + enforce_parser = args.enforce_parser.strip() + if enforce_parser == "": + args.enforce_parser = None + else: + valid_parsers = { + "passthrough", + "json", + "mistral", + "qwen", + "qwen_coder", + "pythonic", + "llama", + "deepseek", + "glm45_moe", + "glm47_moe", + "step3", + "kimik2", + "minimax_m2", + } + if enforce_parser not in valid_parsers: + valid_list = ", ".join(sorted(valid_parsers)) + raise ValueError( + f"Invalid --enforce-parser '{enforce_parser}'. Valid parsers: {valid_list}" + ) + args.enforce_parser = enforce_parser return args def run_server(args): @@ -73,6 +100,7 @@ def run_server(args): model_id=args.m, weight_path=args.w, weight_file=args.f, + enforce_parser=args.enforce_parser, max_num_seqs=max_num_seqs, max_model_len=args.max_model_len, max_tokens=args.max_tokens, diff --git a/src/api.rs b/src/api.rs index 4b0cf14a..842360ba 100644 --- a/src/api.rs +++ b/src/api.rs @@ -124,6 +124,7 @@ impl EngineBuilder { None, None, None, + None, self.isq, Some(self.device_ids.clone().unwrap_or(vec![0]).len()), self.device_ids.clone(), diff --git a/src/core/runner.rs b/src/core/runner.rs index 99346c01..f84e6f57 100644 --- a/src/core/runner.rs +++ b/src/core/runner.rs @@ -576,7 +576,7 @@ impl ModelRunner { // Log thinking parameter only from first rank to avoid duplicate logs in multi-GPU if self.is_first_rank && seqs[0].num_cached_tokens == 0 { - crate::log_warn!( + crate::log_info!( "User's thinking preference for reasoning models: {:?}", user_params.thinking ); diff --git a/src/core/scheduler.rs b/src/core/scheduler.rs index d9f7e58b..3744a6cb 100644 --- a/src/core/scheduler.rs +++ b/src/core/scheduler.rs @@ -566,7 +566,7 @@ impl Scheduler { let mut seq = seq.clone(); seq.num_cached_tokens += CHUNK_SIZE; //current prefilled CHUNK_SIZE seq.status = SequenceStatus::Waiting; - crate::log_warn!( + crate::log_info!( "Seq {} - chunk prefilled {} (remain {} tokens)", seq.id, seq.num_cached_tokens, diff --git a/src/main.rs b/src/main.rs index 6e48446d..e0ab6f7c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ use colored::Colorize; use reedline::{DefaultPrompt, DefaultPromptSegment, Reedline, Signal}; use serde_json; use std::sync::Arc; +use tool_parser::ParserFactory; use vllm_rs::core::engine::StreamItem; use vllm_rs::core::engine::GLOBAL_RT; use vllm_rs::core::{engine::LLMEngine, GenerationOutput}; @@ -27,6 +28,17 @@ async fn main() -> Result<()> { candle_core::bail!("Must provide model_id or weight_path or weight_file!"); } + if let Some(ref enforced) = args.enforce_parser { + let parsers = ParserFactory::new().list_parsers(); + if !parsers.contains(enforced) { + candle_core::bail!( + "Invalid enforce-parser '{}'. Valid parsers: {}", + enforced, + parsers.join(", ") + ); + } + } + let dtype = get_dtype(args.dtype); let (max_num_seqs, interactive) = if args.batch.is_some() { @@ -175,6 +187,7 @@ async fn main() -> Result<()> { args.weight_file, args.hf_token, args.hf_token_path, + args.enforce_parser.clone(), Some(std::cmp::max(max_num_seqs, prompts.len())), None, max_model_len, diff --git a/src/mcp/manager.rs b/src/mcp/manager.rs index f6fa2cc3..701b4a3d 100644 --- a/src/mcp/manager.rs +++ b/src/mcp/manager.rs @@ -392,9 +392,9 @@ fn map_mcp_tools( ); Tool { tool_type: "function".to_string(), - function: crate::tools::FunctionDefinition { + function: crate::tools::Function { name: prefixed_name, - description: tool.description.unwrap_or_default(), + description: tool.description, parameters: tool.input_schema, strict: None, }, diff --git a/src/py/mod.rs b/src/py/mod.rs index d78dd9e2..e08b57e8 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -259,7 +259,7 @@ impl Message { impl EngineConfig { #[new] #[pyo3(signature = (model_id=None, weight_path=None, weight_file=None, - hf_token=None, hf_token_path=None, + hf_token=None, hf_token_path=None, enforce_parser=None, max_num_seqs=Some(32), config_model_len=None, max_model_len=Some(1024), max_tokens=None, isq=None, num_shards=Some(1), device_ids=None, generation_cfg=None, seed=None, prefix_cache=None, prefix_cache_max_tokens=None, @@ -273,6 +273,7 @@ impl EngineConfig { weight_file: Option, hf_token: Option, hf_token_path: Option, + enforce_parser: Option, max_num_seqs: Option, config_model_len: Option, max_model_len: Option, @@ -321,6 +322,7 @@ impl EngineConfig { weight_file, hf_token, hf_token_path, + enforce_parser, num_blocks: 128, //placeholder kv_fraction, cpu_mem_fold, diff --git a/src/server/claude_server.rs b/src/server/claude_server.rs index 7756c958..701d5a29 100644 --- a/src/server/claude_server.rs +++ b/src/server/claude_server.rs @@ -6,7 +6,6 @@ use crate::core::engine::{LLMEngine, StreamItem}; use crate::server::logger::ChatCompletionLogger; use crate::server::parser::{ParserState, StreamResult, StreamToolParser}; use crate::tools::helpers::{build_tool_schema_map, filter_tool_calls}; -use crate::tools::parser::ToolParser; use crate::tools::{Tool, ToolCall, ToolChoice, ToolFormat}; use crate::utils::config::SamplingParams; use axum::{ @@ -423,7 +422,7 @@ fn claude_tools_to_tools(tools: &[ClaudeTool]) -> Vec { .iter() .map(|tool| { let description = tool.description.clone().unwrap_or_default(); - Tool::function(&tool.name, description) + crate::tools::function_tool(&tool.name, description) .parameters_schema(tool.input_schema.clone()) .build() }) @@ -605,7 +604,11 @@ fn convert_claude_message(message: &ClaudeMessage) -> Result, S } flush_content_message(&mut out, role, &mut content_items); let args = serde_json::to_string(input).map_err(|err| err.to_string())?; - tool_calls.push(ToolCall::new(id.clone(), name.clone(), args)); + tool_calls.push(crate::tools::new_tool_call( + id.clone(), + name.clone(), + args, + )); } ClaudeContentBlock::ToolResult { tool_use_id, @@ -694,7 +697,8 @@ fn tool_calls_to_blocks(tool_calls: &[ToolCall]) -> Vec { tool_calls .iter() .map(|call| { - let input = serde_json::from_str(&call.function.arguments).unwrap_or_else(|_| { + let args_str = call.function.arguments.as_deref().unwrap_or("{}"); + let input = serde_json::from_str(args_str).unwrap_or_else(|_| { crate::log_warn!( "Failed to parse tool arguments for '{}'", call.function.name @@ -762,7 +766,7 @@ fn send_tool_use_block( }); stream_ctx.send_json_event("content_block_start", &start_payload)?; - let input_json = call.function.arguments.clone(); + let input_json = call.function.arguments.clone().unwrap_or_default(); let delta = ClaudeContentBlockDeltaEvent { event_type: "content_block_delta", @@ -946,7 +950,12 @@ fn log_tool_calls(label: &str, seq_id: usize, tool_calls: &[ToolCall]) { let summary = tool_calls .iter() .map(|call| { - let args = call.function.arguments.replace('\n', " "); + let args = call + .function + .arguments + .as_deref() + .unwrap_or("") + .replace('\n', " "); let truncated = if args.len() > 160 { let snippet: String = args.chars().take(160).collect(); format!("{}...", snippet) @@ -1158,10 +1167,17 @@ pub async fn messages( }; let _tool_choice = tool_choice_to_openai(&request.tool_choice); - let (model_type, tool_config) = { + let (model_type, tool_config, engine_config) = { let e = data.engine.read(); - (e.model_type.clone(), e.tool_config.clone()) + ( + e.model_type.clone(), + e.tool_config.clone(), + e.econfig.clone(), + ) }; + let parser_model_id = + super::resolve_engine_model_id(&engine_config).unwrap_or_else(|| model_id.clone()); + let enforce_parser = engine_config.enforce_parser.clone(); if !resolved_tools.is_empty() { let tool_prompt_template = data.engine.read().econfig.tool_prompt_template.clone(); @@ -1226,10 +1242,12 @@ pub async fn messages( let engine_clone = data.engine.clone(); let params_clone = params.clone(); let stream_model_id = model_id.clone(); + let stream_parser_model_id = parser_model_id.clone(); let stream_model_type = model_type.clone(); let stream_tool_config = tool_config.clone(); let stream_tool_schemas = tool_schemas.clone(); let forced_tool_name = forced_tool_name.clone(); + let stream_tools = resolved_tools.clone(); if let Some(ref l) = logger { l.log_start_response(); } @@ -1337,8 +1355,10 @@ pub async fn messages( let mut pending_tool_calls: Vec = Vec::new(); let mut tool_parser = StreamToolParser::new_with_config( &stream_model_type, - stream_model_id.clone(), + stream_parser_model_id.clone(), stream_tool_config, + stream_tools.clone(), + enforce_parser.clone(), ); let should_parse_tools = params_clone.mcp_mode.is_some(); @@ -1376,7 +1396,7 @@ pub async fn messages( total_decoded_tokens += 1; if should_parse_tools { - match tool_parser.process_token(token_id, &token) { + match tool_parser.process_token(token_id, &token).await { StreamResult::Content(text) => { if text.is_empty() { continue; @@ -1476,63 +1496,67 @@ pub async fn messages( total_decoded_tokens = final_decoded_length; if should_parse_tools { - match tool_parser.state() { - ParserState::Buffering => { - if let Some(mut parsed) = tool_parser.finalize() { - pending_tool_calls.append(&mut parsed); - } else { - let buffer = tool_parser.take_buffer(); - if !buffer.is_empty() { - let _ = send_text_with_start( - &stream_ctx, - &mut text_block_started, - text_block_index, - &buffer, - ); - } - } - } - ParserState::MaybeStart => { - let buffer = tool_parser.take_buffer(); - if !buffer.is_empty() { - let _ = send_text_with_start( - &stream_ctx, - &mut text_block_started, - text_block_index, - &buffer, - ); - } + if matches!(tool_parser.state(), ParserState::Buffering) { + let buffer = tool_parser.take_buffer(); + if !buffer.is_empty() { + let _ = send_text_with_start( + &stream_ctx, + &mut text_block_started, + text_block_index, + &buffer, + ); } - ParserState::Normal => {} } } let (tool_calls, has_tool_calls) = if pending_tool_calls.is_empty() { (Vec::new(), false) } else { - let (valid, invalid) = filter_tool_calls( + let (validated_calls, invalid) = filter_tool_calls( &pending_tool_calls, stream_tool_schemas.as_ref(), ); if !invalid.is_empty() { crate::log_warn!( - "[Seq {}] Dropping {} invalid tool call(s)", + "[Seq {}] Found {} invalid tool call(s)", seq_id, invalid.len() ); - log_tool_calls("Invalid", seq_id, &invalid); - if let Some(ref l) = stream_logger { - l.log_tool_calls("Invalid", &invalid); - } } - if valid.is_empty() { + let strict_mode = std::env::var("VLLM_RS_STRICT_TOOL_CALL") + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + let final_tool_calls = if strict_mode { + if !invalid.is_empty() { + crate::log_warn!( + "[Seq {}] Strict mode enabled, dropping invalid calls", + seq_id + ); + } + validated_calls + } else { + if !invalid.is_empty() { + crate::log_warn!( + "[Seq {}] Strict mode disabled, keeping invalid calls", + seq_id + ); + log_tool_calls("Invalid", seq_id, &invalid); + if let Some(ref l) = stream_logger { + l.log_tool_calls("Invalid", &invalid); + } + } + pending_tool_calls + }; + + if final_tool_calls.is_empty() { (Vec::new(), false) } else { - log_tool_calls("Valid", seq_id, &valid); + log_tool_calls("Valid", seq_id, &final_tool_calls); if let Some(ref l) = stream_logger { - l.log_tool_calls("Valid", &valid); + l.log_tool_calls("Valid", &final_tool_calls); } - (valid, true) + (final_tool_calls, true) } }; @@ -1807,19 +1831,39 @@ pub async fn messages( } }; - let tool_parser = ToolParser::new(); - let parsed_calls = tool_parser.parse(&output.decode_output); - let (valid_calls, invalid_calls) = filter_tool_calls(&parsed_calls, tool_schemas.as_ref()); + let tool_parser = StreamToolParser::new_with_config( + &model_type, + parser_model_id.clone(), + tool_config.clone(), + resolved_tools.clone(), + enforce_parser.clone(), + ); + let parsed_calls = tool_parser + .parse_complete_with_fallback(&output.decode_output) + .await; + let (validated_calls, invalid_calls) = + filter_tool_calls(&parsed_calls, tool_schemas.as_ref()); + if !invalid_calls.is_empty() { - crate::log_warn!( - "Dropping {} invalid tool call(s) for Claude response", - invalid_calls.len() - ); - log_tool_calls("Invalid", output.seq_id, &invalid_calls); - if let Some(ref l) = logger { - l.log_tool_calls("Invalid", &invalid_calls); - } + crate::log_warn!("Found {} invalid tool call(s)", invalid_calls.len()); } + + let strict_mode = std::env::var("VLLM_RS_STRICT_TOOL_CALL") + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + let valid_calls = if strict_mode { + if !invalid_calls.is_empty() { + crate::log_warn!("Strict mode enabled, dropping invalid calls"); + } + validated_calls + } else { + if !invalid_calls.is_empty() { + crate::log_warn!("Strict mode disabled, keeping invalid calls"); + } + parsed_calls + }; + if !valid_calls.is_empty() { log_tool_calls("Valid", output.seq_id, &valid_calls); if let Some(ref l) = logger { @@ -2098,12 +2142,12 @@ mod tests { let schema = SchemaBuilder::object() .string_prop("path", "Path to list", true) .build(); - let tools = vec![Tool::function("list_files", "List files") + let tools = vec![crate::tools::function_tool("list_files", "List files") .parameters_schema(schema) .build()]; let schemas = build_tool_schema_map(&tools); - let valid_call = ToolCall::new("call_1", "list_files", r#"{"path": "."}"#); - let invalid_call = ToolCall::new("call_2", "list_files", r#"{"dir": "."}"#); + let valid_call = crate::tools::new_tool_call("call_1", "list_files", r#"{"path": "."}"#); + let invalid_call = crate::tools::new_tool_call("call_2", "list_files", r#"{"dir": "."}"#); let (valid, invalid) = filter_tool_calls(&[valid_call, invalid_call], &schemas); assert_eq!(valid.len(), 1); diff --git a/src/server/mod.rs b/src/server/mod.rs index 3b8d867e..d8394213 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -26,6 +26,7 @@ use parking_lot::RwLock; use rustchatui::start_ui_server; use serde_json::json; use std::collections::HashMap; +use std::path::Path; use std::sync::Arc; use tower_http::cors::{Any, CorsLayer}; @@ -51,6 +52,41 @@ pub struct ChatCompletionRequest { pub tool_choice: Option, } +pub fn resolve_engine_model_id(econfig: &EngineConfig) -> Option { + if let Some(model_id) = &econfig.model_id { + if !model_id.trim().is_empty() { + return Some(model_id.clone()); + } + } + + if let Some(weight_path) = &econfig.weight_path { + let trimmed = weight_path.trim_end_matches(['/', '\\']); + let path = Path::new(trimmed); + if let Some(name) = path.file_name().and_then(|s| s.to_str()) { + if !name.is_empty() { + return Some(name.to_string()); + } + } + if let Some(component) = path.components().last() { + let name = component.as_os_str().to_string_lossy().to_string(); + if !name.is_empty() { + return Some(name); + } + } + } + + if let Some(weight_file) = &econfig.weight_file { + let path = Path::new(weight_file); + if let Some(name) = path.file_name().and_then(|s| s.to_str()) { + if !name.is_empty() { + return Some(name.to_string()); + } + } + } + + None +} + #[derive(Deserialize)] #[serde(rename_all = "snake_case")] pub enum EncodingFormat { @@ -183,6 +219,17 @@ pub struct ChatChoice { pub finish_reason: Option, } +/// Public tool call structure with correct serialization fields +#[derive(Serialize, Debug, Clone)] +pub struct PublicToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + pub index: Option, + pub id: String, + #[serde(rename = "type")] + pub type_: String, + pub function: crate::tools::FunctionCall, +} + /// Message in the response (may contain tool calls) #[derive(Serialize)] pub struct ChatResponseMessage { @@ -190,7 +237,7 @@ pub struct ChatResponseMessage { #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, + pub tool_calls: Option>, } #[derive(Serialize, Debug)] @@ -228,7 +275,7 @@ pub struct Delta { #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, + pub tool_calls: Option>, } #[derive(Serialize)] @@ -442,6 +489,10 @@ pub struct Args { #[arg(long = "f")] pub weight_file: Option, + /// Enforce a specific tool-call parser (e.g., qwen, qwen_coder, json) + #[arg(long, default_value = None)] + pub enforce_parser: Option, + pub hf_token: Option, pub hf_token_path: Option, @@ -570,8 +621,9 @@ pub async fn execute_mcp_tool_calls_async( followup_messages.push(ChatMessage::with_tool_calls(tool_calls.clone())); for call in &tool_calls { - let args_value: serde_json::Value = serde_json::from_str(&call.function.arguments) - .unwrap_or_else(|_| serde_json::json!({"raw": call.function.arguments})); + let args_str = call.function.arguments.as_deref().unwrap_or("{}"); + let args_value: serde_json::Value = + serde_json::from_str(args_str).unwrap_or_else(|_| serde_json::json!({"raw": args_str})); let args_map = args_value .as_object() .cloned() diff --git a/src/server/parser.rs b/src/server/parser.rs index 7f530126..743da9d8 100644 --- a/src/server/parser.rs +++ b/src/server/parser.rs @@ -3,19 +3,22 @@ //! Handles model-specific tool call tokens and formats. use crate::server::{ChatChoiceChunk, ChatCompletionChunk, Delta}; -use crate::tools::{FunctionCall, ToolCall}; +use crate::tools::{Tool, ToolCall}; use crate::utils::config::ModelType; use serde_json::Value; use std::collections::HashSet; use tokenizers::Tokenizer; - +use tool_parser::{ + types::{StreamingParseResult, ToolCallItem}, + ParserFactory, ToolParser as ExternalToolParser, +}; /// Parser state for streaming tool call detection #[derive(Debug, Clone, PartialEq)] pub enum ParserState { /// Normal streaming mode - tokens pass through Normal, /// Potential start tag detected (partial match) - MaybeStart, + // MaybeStart, /// Buffering mode - accumulating confirmed tool call content Buffering, } @@ -201,14 +204,15 @@ pub struct StreamToolParser { buffer: String, model_id: String, parse_strategy: String, + parser: Box, + tools: Vec, + streaming_calls: Vec, // Accumulated output for final parsing accumulated_output: String, // Reasoning block tracking active_reasoning_end: Option<&'static str>, // Code block tracking in_code_block: bool, - // Tool call index counter - tool_call_index: usize, } /// Reasoning marker pairs: (start, end) @@ -223,27 +227,70 @@ impl StreamToolParser { /// Create a new parser for the given model type pub fn new(model_type: ModelType, model_id: String) -> Self { let config = ToolConfig::for_model_type(&model_type); - Self::new_with_config(&model_type, model_id, config) + Self::new_with_config(&model_type, model_id, config, Vec::new(), None) } /// Create a new parser with a pre-validated tool config - pub fn new_with_config(model_type: &ModelType, model_id: String, config: ToolConfig) -> Self { + pub fn new_with_config( + model_type: &ModelType, + model_id: String, + config: ToolConfig, + tools: Vec, + enforce_parser: Option, + ) -> Self { let parse_strategy = match model_type { ModelType::Mistral | ModelType::Mistral3VL => "mistral_list", _ => "json", } .to_string(); + let factory = ParserFactory::new(); + let parser_name = if let Some(name) = enforce_parser.as_ref().and_then(|s| { + let trimmed = s.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed) + } + }) { + if !factory.registry().has_parser(name) { + let valid = factory.list_parsers().join(", "); + panic!( + "Invalid enforce-parser '{}'. Valid parsers: {}", + name, valid + ); + } + name + } else { + Self::parser_name_for_model(model_type, &model_id) + }; + if !tools.is_empty() { + crate::log_info!( + "Tool parser selected: {} (model_id={}, enforce_parser={})", + parser_name, + model_id, + enforce_parser.as_deref().unwrap_or("none") + ); + } + let parser = factory + .registry() + .create_parser(parser_name) + .or_else(|| factory.registry().create_for_model(&model_id)) + .or_else(|| factory.registry().create_parser("passthrough")) + .expect("tool parser available"); + Self { config, state: ParserState::Normal, buffer: String::new(), model_id, parse_strategy, + parser, + tools, + streaming_calls: Vec::new(), accumulated_output: String::new(), active_reasoning_end: None, in_code_block: false, - tool_call_index: 0, } } @@ -274,7 +321,7 @@ impl StreamToolParser { /// Process a single incoming token. /// Returns StreamResult indicating what action to take. - pub fn process_token(&mut self, token_id: u32, token_text: &str) -> StreamResult { + pub async fn process_token(&mut self, token_id: u32, token_text: &str) -> StreamResult { // Always accumulate self.accumulated_output.push_str(token_text); @@ -314,16 +361,11 @@ impl StreamToolParser { if self.is_start_token(token_id, token_text) { self.state = ParserState::Buffering; self.buffer.clear(); - - if let Some(pos) = token_text.find(&self.config.start_token_str) { - let before = &token_text[..pos]; - let after = &token_text[pos + self.config.start_token_str.len()..]; - if !after.is_empty() { - self.buffer.push_str(after); - } - if !before.is_empty() { - return StreamResult::Content(before.to_string()); - } + self.buffer.push_str(token_text); + self.streaming_calls.clear(); + if let Ok(result) = self.parser.parse_incremental(token_text, &self.tools).await + { + self.apply_streaming_result(&result); } crate::log_info!( @@ -333,56 +375,17 @@ impl StreamToolParser { ); return StreamResult::Buffering; } - - // Check for partial tag match at end of current token - if !self.config.has_start_tokens() { - if let Some((prefix, partial)) = self.split_partial_start(token_text) { - self.state = ParserState::MaybeStart; - self.buffer.clear(); - self.buffer.push_str(&partial); - return if prefix.is_empty() { - StreamResult::Buffering - } else { - StreamResult::Content(prefix) - }; - } - } - // Normal content StreamResult::Content(token_text.to_string()) } - ParserState::MaybeStart => { + ParserState::Buffering => { self.buffer.push_str(token_text); - - if let Some(tag_pos) = self.buffer.find(&self.config.start_token_str) { - let before = self.buffer[..tag_pos].to_string(); - let after = - self.buffer[tag_pos + self.config.start_token_str.len()..].to_string(); - self.buffer.clear(); - if !after.is_empty() { - self.buffer.push_str(&after); + if let Ok(result) = self.parser.parse_incremental(token_text, &self.tools).await { + if !result.calls.is_empty() { + crate::log_info!("Stream parsing: {:?}", result.calls); } - self.state = ParserState::Buffering; - return if before.is_empty() { - StreamResult::Buffering - } else { - StreamResult::Content(before) - }; + self.apply_streaming_result(&result); } - - if self.partial_suffix_len(&self.buffer) > 0 { - return StreamResult::Buffering; - } - - // False alarm - not a tool call tag - self.state = ParserState::Normal; - let flushed = self.buffer.clone(); - self.buffer.clear(); - StreamResult::FlushBuffer(flushed) - } - ParserState::Buffering => { - self.buffer.push_str(token_text); - let end_reached = self.is_end_token(token_id, token_text) || self.buffer_has_end_tag() || self.maybe_complete_mistral_list(); @@ -393,20 +396,28 @@ impl StreamToolParser { token_id ); - let tool_calls = self.parse_buffer(); + if let Some(unstreamed) = self.parser.get_unstreamed_tool_args() { + self.apply_stream_items(&unstreamed); + } + let mut tool_calls = self.build_tool_calls_from_streaming(); + if tool_calls.is_empty() { + crate::log_info!( + "Fallback to non-stream parsing for buffer: {}", + self.buffer + ); + tool_calls = self.parse_complete_with_fallback(&self.buffer).await; + } let result = if tool_calls.is_empty() { // Parse failed - return buffered content - crate::log_error!( - "Unable to parse tool call buffer: {}\n of accumulated buffer: {}", - self.buffer, - self.accumulated_output - ); + crate::log_error!("Unable to parse tool call buffer: {}", self.buffer,); StreamResult::FlushBuffer(self.buffer.clone()) } else { StreamResult::ToolCalls(tool_calls) }; + self.parser.reset(); self.buffer.clear(); self.state = ParserState::Normal; + self.streaming_calls.clear(); return result; } @@ -415,31 +426,6 @@ impl StreamToolParser { } } - /// Finalize parsing when stream ends - pub fn finalize(&mut self) -> Option> { - match self.state { - ParserState::Buffering => { - if self.buffer.is_empty() { - self.state = ParserState::Normal; - return None; - } - let tool_calls = self.parse_buffer(); - if !tool_calls.is_empty() { - self.buffer.clear(); - self.state = ParserState::Normal; - return Some(tool_calls); - } - // Leave buffer intact so caller can flush it. - self.state = ParserState::Normal; - } - ParserState::MaybeStart => { - self.state = ParserState::Normal; - } - ParserState::Normal => {} - } - None - } - /// Drain the buffer and reset parser state. pub fn take_buffer(&mut self) -> String { self.state = ParserState::Normal; @@ -469,137 +455,82 @@ impl StreamToolParser { text.contains(&self.config.end_token_str) } - /// Parse buffered content into tool calls - fn parse_buffer(&mut self) -> Vec { - let mut clean_text = self.buffer.trim().to_string(); - if self.should_strip_end_tag() { - if let Some(pos) = clean_text.rfind(&self.config.end_token_str) { - clean_text.truncate(pos); - } + fn apply_streaming_result(&mut self, result: &StreamingParseResult) { + if !result.calls.is_empty() { + self.apply_stream_items(&result.calls); } - let mut calls = Vec::new(); + } - // Strategy 1: Mistral List [ {...}, {...} ] - if self.parse_strategy == "mistral_list" && clean_text.starts_with('[') { - if let Ok(list) = serde_json::from_str::>(&clean_text) { - for item in list.iter() { - if let Some(call) = self.json_to_tool_call(item) { - calls.push(call); - } - } + fn apply_stream_items(&mut self, items: &[ToolCallItem]) { + for item in items { + if self.streaming_calls.len() <= item.tool_index { + self.streaming_calls + .resize_with(item.tool_index + 1, StreamingToolCallState::default); } - } - // Strategy 2: Single JSON Object (Qwen, Llama, Phi) - else if let Ok(item) = serde_json::from_str::(&clean_text) { - if let Some(call) = self.json_to_tool_call(&item) { - calls.push(call); + let state = &mut self.streaming_calls[item.tool_index]; + if let Some(name) = &item.name { + state.name = Some(name.clone()); + } + if !item.parameters.is_empty() { + state.arguments.push_str(&item.parameters); } } - // Strategy 3: QwenCoder XML-style function tags - else if clean_text.starts_with("") { - // Extract function name from tag - let func_start = "').unwrap_or(0); - if func_end > func_start { - let func_name = &clean_text[func_start..func_end]; - - // Find parameter section - let params_start = clean_text.find(""); - - let mut params = std::collections::HashMap::new(); - - if let (Some(start), Some(end)) = (params_start, params_end) { - let param_content = &clean_text[start..end]; - - // Parse all parameter tags - let mut pos = 0; - while pos < param_content.len() { - let param_start_tag = param_content[pos..].find("").unwrap_or(0) + key_start; - - if key_end > key_start { - let key = ¶m_content[key_start..key_end]; - - let value_start = key_end + 1; - let value_end = param_content[value_start..] - .find("") - .unwrap_or(0) - + value_start; - - if value_end > value_start { - let value = ¶m_content[value_start..value_end]; - params.insert(key.to_string(), value.trim().to_string()); - } - } - - // Move position past this parameter - let next_pos = param_content[pos..].find("").unwrap_or(0) - + pos - + "".len(); - if next_pos <= pos { - break; - } - pos = next_pos; - } - } - - if let Ok(args) = serde_json::to_string(¶ms) { - let call = ToolCall { - index: Some(self.tool_call_index), - id: format!("call_{}", uuid::Uuid::new_v4().simple()), - call_type: "function".to_string(), - function: FunctionCall { - name: func_name.to_string(), - arguments: args, - }, - }; + } - self.tool_call_index += 1; - calls.push(call); - } - } + fn build_tool_calls_from_streaming(&mut self) -> Vec { + let mut calls = Vec::new(); + crate::log_info!("Building tool call: {:?}", self.streaming_calls); + for state in &self.streaming_calls { + let Some(name) = &state.name else { continue }; + let args = if state.arguments.trim().is_empty() { + "{}".to_string() + } else { + state.arguments.clone() + }; + calls.push(crate::tools::new_tool_call( + format!("call_{}", uuid::Uuid::new_v4().simple()), + name.clone(), + args, + )); } - // Strategy 4: Repair unbalanced JSON and retry - else if let Some(repaired) = self.repair_unbalanced_json(&clean_text) { - if repaired != clean_text { - crate::log_warn!("Tool call JSON missing closing braces; attempting repair"); + calls + } + + pub async fn parse_complete_with_fallback(&self, text: &str) -> Vec { + let mut parsed_calls = match self.parser.parse_complete(text).await { + Ok((_normal_text, calls)) => calls, + Err(err) => { + crate::log_warn!("Tool parse failed: {:?}", err); + Vec::new() } - if let Ok(item) = serde_json::from_str::(&repaired) { - if let Some(call) = self.json_to_tool_call(&item) { - calls.push(call); + }; + + if parsed_calls.is_empty() && text.contains(" Option<(String, String)> { - let tag = &self.config.start_token_str; - let suffix_len = self.partial_suffix_len(text); - if suffix_len > 0 && suffix_len < tag.len() { - let prefix = text[..text.len() - suffix_len].to_string(); - let partial = text[text.len() - suffix_len..].to_string(); - return Some((prefix, partial)); - } - None - } - - fn partial_suffix_len(&self, text: &str) -> usize { - let tag = &self.config.start_token_str; - let max = std::cmp::min(tag.len(), text.len()); - for i in (1..=max).rev() { - if text.ends_with(&tag[..i]) { - return i; + if parsed_calls.is_empty() + && self.config.start_token_str.starts_with('<') + && self.config.end_token_str.starts_with('<') + { + let stripped = self.strip_tool_tags(text); + let factory = ParserFactory::new(); + if let Some(json_parser) = factory.registry().create_parser("json") { + if let Ok((_normal_text, calls)) = json_parser.parse_complete(&stripped).await { + parsed_calls = calls; + } } } - 0 + + parsed_calls + .into_iter() + .map(crate::tools::tool_call_from_parser) + .collect() } fn buffer_has_end_tag(&self) -> bool { @@ -626,92 +557,35 @@ impl StreamToolParser { serde_json::from_str::>(trimmed).is_ok() } - fn should_strip_end_tag(&self) -> bool { - let end_tag = self.config.end_token_str.as_str(); - if end_tag.is_empty() { - return false; - } - if self.parse_strategy == "mistral_list" && end_tag == "]" { - return false; - } - end_tag.starts_with('<') - } - - fn repair_unbalanced_json(&self, text: &str) -> Option { - let trimmed = text.trim(); - if !(trimmed.starts_with('{') || trimmed.starts_with('[')) { - return None; - } - - let mut in_string = false; - let mut escape = false; - let mut open_braces = 0usize; - let mut close_braces = 0usize; - let mut open_brackets = 0usize; - let mut close_brackets = 0usize; - - for ch in trimmed.chars() { - if escape { - escape = false; - continue; - } - match ch { - '\\' if in_string => { - escape = true; - } - '"' => { - in_string = !in_string; + fn parser_name_for_model(model_type: &ModelType, model_id: &str) -> &'static str { + let model_lower = model_id.to_ascii_lowercase(); + match model_type { + ModelType::LLaMa => "llama", + ModelType::Mistral | ModelType::Mistral3VL => "mistral", + ModelType::Qwen3 | ModelType::Qwen3MoE | ModelType::Qwen3VL => { + if model_lower.contains("coder") { + "qwen_coder" + } else { + "qwen" } - '{' if !in_string => open_braces += 1, - '}' if !in_string => close_braces += 1, - '[' if !in_string => open_brackets += 1, - ']' if !in_string => close_brackets += 1, - _ => {} } + ModelType::Gemma | ModelType::Gemma3 => "json", + ModelType::Phi | ModelType::Phi4 => "qwen", + ModelType::GLM4 | ModelType::GLM4MoE => "glm47_moe", + ModelType::Yi | ModelType::StableLM => "qwen", + ModelType::DeepSeek => "deepseek", } + } - if in_string { - return None; - } - if close_braces > open_braces || close_brackets > open_brackets { - return None; - } - - if open_braces == close_braces && open_brackets == close_brackets { - return None; - } - - let mut fixed = trimmed.to_string(); - if open_brackets > close_brackets { - fixed.push_str(&"]".repeat(open_brackets - close_brackets)); + fn strip_tool_tags(&self, text: &str) -> String { + let mut output = text.to_string(); + if !self.config.start_token_str.is_empty() { + output = output.replace(&self.config.start_token_str, ""); } - if open_braces > close_braces { - fixed.push_str(&"}".repeat(open_braces - close_braces)); + if !self.config.end_token_str.is_empty() { + output = output.replace(&self.config.end_token_str, ""); } - Some(fixed) - } - - /// Convert JSON value to ToolCall - fn json_to_tool_call(&mut self, item: &Value) -> Option { - let name = item["name"].as_str()?.to_string(); - let arguments = if let Some(args) = item.get("arguments") { - if args.is_string() { - args.as_str().unwrap_or("{}").to_string() - } else { - args.to_string() - } - } else { - "{}".to_string() - }; - - let call = ToolCall { - index: Some(self.tool_call_index), - id: format!("call_{}", uuid::Uuid::new_v4().simple()), - call_type: "function".to_string(), - function: FunctionCall { name, arguments }, - }; - self.tool_call_index += 1; - Some(call) + output } // --- Chunk creation helpers (for use by server.rs) --- @@ -753,7 +627,18 @@ impl StreamToolParser { index: 0, delta: Delta { content: None, - tool_calls: Some(tools), + tool_calls: Some( + tools + .into_iter() + .enumerate() + .map(|(i, tc)| crate::server::PublicToolCall { + index: Some(i), + id: tc.id, + type_: tc.tool_type, + function: tc.function, + }) + .collect(), + ), }, finish_reason: None, error: None, @@ -763,6 +648,12 @@ impl StreamToolParser { } } +#[derive(Debug, Clone, Default)] +struct StreamingToolCallState { + name: Option, + arguments: String, +} + #[cfg(test)] mod tests { use super::*; @@ -783,33 +674,50 @@ mod tests { assert_eq!(config.start_token_str, ""); } - #[test] - fn test_parser_normal_content() { - let mut parser = StreamToolParser::new(ModelType::Qwen3, "qwen3".to_string()); - match parser.process_token(0, "Hello world") { + #[tokio::test] + async fn test_parser_normal_content() { + let tools = vec![crate::tools::function_tool("test", "desc").build()]; + let mut parser = StreamToolParser::new_with_config( + &ModelType::Qwen3, + "qwen3".to_string(), + ToolConfig::for_model_type(&ModelType::Qwen3), + tools, + None, + ); + match parser.process_token(0, "Hello world").await { StreamResult::Content(s) => assert_eq!(s, "Hello world"), _ => panic!("Expected Content"), } } - #[test] - fn test_parser_tool_call_detection() { - let mut parser = StreamToolParser::new(ModelType::Qwen3, "qwen3".to_string()); + #[tokio::test] + async fn test_parser_tool_call_detection() { + let tools = vec![crate::tools::function_tool("test", "desc").build()]; + let mut parser = StreamToolParser::new_with_config( + &ModelType::Qwen3, + "qwen3".to_string(), + ToolConfig::for_model_type(&ModelType::Qwen3), + tools, + None, + ); // Start tag triggers buffering - match parser.process_token(151657, "") { + match parser.process_token(151657, "").await { StreamResult::Buffering => {} _ => panic!("Expected Buffering on start tag"), } // Content is buffered - match parser.process_token(0, r#"{"name": "test", "arguments": {}}"#) { + match parser + .process_token(0, r#"{"name": "test", "arguments": {}}"#) + .await + { StreamResult::Buffering => {} _ => panic!("Expected Buffering"), } // End tag triggers parsing - match parser.process_token(151658, "") { + match parser.process_token(151658, "").await { StreamResult::ToolCalls(calls) => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "test"); @@ -818,24 +726,34 @@ mod tests { } } - #[test] - fn test_parser_partial_start_text_mode() { - let mut parser = StreamToolParser::new(ModelType::Phi, "phi".to_string()); + #[tokio::test] + async fn test_parser_partial_start_text_mode() { + let tools = vec![crate::tools::function_tool("test", "desc").build()]; + let mut parser = StreamToolParser::new_with_config( + &ModelType::Phi, + "phi".to_string(), + ToolConfig::for_model_type(&ModelType::Phi), + tools, + None, + ); // Partial start tag splits across tokens - match parser.process_token(0, " {} _ => panic!("Expected Buffering on partial start"), } - match parser.process_token(0, "call>") { + match parser.process_token(0, "call>").await { StreamResult::Buffering => {} _ => panic!("Expected Buffering on completed start"), } - match parser.process_token(0, r#"{"name": "test", "arguments": {}}"#) { + match parser + .process_token(0, r#"{"name": "test", "arguments": {}}"#) + .await + { StreamResult::Buffering => {} _ => panic!("Expected Buffering"), } - match parser.process_token(0, "") { + match parser.process_token(0, "").await { StreamResult::ToolCalls(calls) => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "test"); @@ -844,12 +762,19 @@ mod tests { } } - #[test] - fn test_parser_token_id_strict_match() { - let mut parser = StreamToolParser::new(ModelType::Qwen3, "qwen3".to_string()); + #[tokio::test] + async fn test_parser_token_id_strict_match() { + let tools = vec![crate::tools::function_tool("test", "desc").build()]; + let mut parser = StreamToolParser::new_with_config( + &ModelType::Qwen3, + "qwen3".to_string(), + ToolConfig::for_model_type(&ModelType::Qwen3), + tools, + None, + ); // Text match should not trigger when token IDs are available - match parser.process_token(0, "") { + match parser.process_token(0, "").await { StreamResult::Content(text) => assert_eq!(text, ""), _ => panic!("Expected Content without token ID match"), } diff --git a/src/server/server.rs b/src/server/server.rs index b1ac4085..ee1928fb 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -16,7 +16,6 @@ use crate::server::parser::{ParserState, StreamResult, StreamToolParser}; use crate::tools::helpers::{ build_tool_schema_map, filter_tool_calls, log_tool_calls, resolve_tools, }; -use crate::tools::parser::ToolParser; use crate::tools::{ToolChoice, ToolFormat}; use crate::utils::config::SamplingParams; use axum::{ @@ -111,12 +110,13 @@ pub async fn chat_completion( params.presence_penalty = request.presence_penalty; params.session_id = request.session_id.clone(); params.thinking = request.thinking.clone(); - let (img_cfg, model_type, tool_config) = { + let (img_cfg, model_type, tool_config, engine_config) = { let e = data.engine.read(); ( e.img_cfg.clone(), e.model_type.clone(), e.tool_config.clone(), + e.econfig.clone(), ) }; @@ -175,6 +175,9 @@ pub async fn chat_completion( } let mut chat_messages = request.messages.clone(); + let parser_model_id = + super::resolve_engine_model_id(&engine_config).unwrap_or_else(|| model_id.clone()); + let enforce_parser = engine_config.enforce_parser.clone(); if has_tools { let tool_prompt_template = data.engine.read().econfig.tool_prompt_template.clone(); let mut tool_prompt = if let Some(template) = tool_prompt_template { @@ -259,8 +262,13 @@ pub async fn chat_completion( let _img_cfg_clone = img_cfg.clone(); let tool_config = tool_config.clone(); - let tool_parser = - StreamToolParser::new_with_config(&model_type, model_id.to_string(), tool_config); + let tool_parser = StreamToolParser::new_with_config( + &model_type, + parser_model_id.clone(), + tool_config, + resolved_tools.clone(), + enforce_parser.clone(), + ); let forced_tool_name = forced_tool_name.clone(); let stream_tool_schemas = tool_schemas.clone(); if let Some(ref l) = logger { @@ -321,7 +329,7 @@ pub async fn chat_completion( // Use StreamToolParser for all tool call detection and buffering if should_parse_tools { - match tool_parser.process_token(token_id, &token) { + match tool_parser.process_token(token_id, &token).await { StreamResult::Content(text) => { if text.is_empty() { continue; @@ -393,45 +401,18 @@ pub async fn chat_completion( )) => { total_decoded_tokens += final_decoded_length; - // Finalize tool parsing and collect any remaining tool calls + // Flush any buffered content at end of stream if should_parse_tools { - match tool_parser.state() { - ParserState::Buffering => { - // Finalize any buffered content - if let Some(mut parsed) = tool_parser.finalize() { - if !parsed.is_empty() { - crate::log_info!( - "[Seq {}] Parsed {} tool call(s)", - current_seq_id, - parsed.len() - ); - } - pending_tool_calls.append(&mut parsed); - } else { - // Parse failed - flush any remaining buffer as text - let buffer = tool_parser.take_buffer(); - if !buffer.is_empty() { - crate::log_warn!( - "[Seq {}] Tool parse failed, flushing {} chars", - current_seq_id, - buffer.len() - ); - stream_ctx.send_token(&buffer); - } - } - } - ParserState::MaybeStart => { - let buffer = tool_parser.take_buffer(); - if !buffer.is_empty() { - crate::log_warn!( - "[Seq {}] Tool parse partial, flushing {} chars", - current_seq_id, - buffer.len() - ); - stream_ctx.send_token(&buffer); - } + if matches!(tool_parser.state(), ParserState::Buffering) { + let buffer = tool_parser.take_buffer(); + if !buffer.is_empty() { + crate::log_warn!( + "[Seq {}] Tool parse partial, flushing {} chars", + current_seq_id, + buffer.len() + ); + stream_ctx.send_token(&buffer); } - ParserState::Normal => {} } } @@ -449,11 +430,12 @@ pub async fn chat_completion( } } - let (valid_calls, invalid_calls) = + let (validated_calls, invalid_calls) = filter_tool_calls(&pending_tool_calls, stream_tool_schemas.as_ref()); + if !invalid_calls.is_empty() { crate::log_warn!( - "[Seq {}] Dropped {} invalid tool call(s)", + "[Seq {}] Found {} invalid tool call(s)", current_seq_id, invalid_calls.len() ); @@ -463,11 +445,44 @@ pub async fn chat_completion( } } + let strict_mode = std::env::var("VLLM_RS_STRICT_TOOL_CALL") + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + let valid_calls = if strict_mode { + if !invalid_calls.is_empty() { + crate::log_warn!( + "[Seq {}] Strict mode enabled, dropping invalid calls", + current_seq_id + ); + } + validated_calls + } else { + if !invalid_calls.is_empty() { + crate::log_warn!( + "[Seq {}] Strict mode disabled, keeping invalid calls", + current_seq_id + ); + } + pending_tool_calls + }; + let tool_calls = if valid_calls.is_empty() { None } else { log_tool_calls("Valid", &valid_calls); - Some(valid_calls) + Some( + valid_calls + .into_iter() + .enumerate() + .map(|(i, tc)| crate::server::PublicToolCall { + index: Some(i), + id: tc.id, + type_: tc.tool_type, + function: tc.function, + }) + .collect(), + ) }; let has_any_tool_calls = tool_calls.is_some(); if tool_choice_required && !has_any_tool_calls { @@ -653,10 +668,17 @@ pub async fn chat_completion( total_decoded_time_taken += decode_time_taken; // Parse tool calls from the model output if tools were provided - let tool_parser = ToolParser::new(); - let (content, tool_calls) = if has_tools { - let mut parsed_calls = tool_parser.parse(&output.decode_output); + let tool_parser = StreamToolParser::new_with_config( + &model_type, + parser_model_id.clone(), + tool_config.clone(), + resolved_tools.clone(), + enforce_parser.clone(), + ); + let mut parsed_calls = tool_parser + .parse_complete_with_fallback(&output.decode_output) + .await; if let Some(ref forced_name) = forced_tool_name { let before = parsed_calls.len(); parsed_calls.retain(|call| call.function.name == *forced_name); @@ -669,12 +691,29 @@ pub async fn chat_completion( ); } } - let (valid_calls, invalid_calls) = + let (validated_calls, invalid_calls) = filter_tool_calls(&parsed_calls, tool_schemas.as_ref()); + if !invalid_calls.is_empty() { - crate::log_warn!("Dropped {} invalid tool call(s)", invalid_calls.len()); + crate::log_warn!("Found {} invalid tool call(s)", invalid_calls.len()); log_tool_calls("Invalid", &invalid_calls); } + + let strict_mode = std::env::var("VLLM_RS_STRICT_TOOL_CALL") + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + let valid_calls = if strict_mode { + if !invalid_calls.is_empty() { + crate::log_warn!("Strict mode enabled, dropping invalid calls"); + } + validated_calls + } else { + if !invalid_calls.is_empty() { + crate::log_warn!("Strict mode disabled, keeping invalid calls"); + } + parsed_calls + }; if valid_calls.is_empty() { if tool_choice_required { crate::log_warn!("Tool choice required but no tool calls were produced"); @@ -682,7 +721,16 @@ pub async fn chat_completion( (Some(output.decode_output), None) } else { log_tool_calls("Valid", &valid_calls); - (None, Some(valid_calls)) + let public_calls = valid_calls + .into_iter() + .map(|tc| crate::server::PublicToolCall { + index: None, + id: tc.id, + type_: tc.tool_type, + function: tc.function, + }) + .collect(); + (None, Some(public_calls)) } } else { (Some(output.decode_output), None) diff --git a/src/tools/helpers.rs b/src/tools/helpers.rs index 69c7a861..b7bbedbc 100644 --- a/src/tools/helpers.rs +++ b/src/tools/helpers.rs @@ -48,14 +48,15 @@ pub fn filter_tool_calls( } }; - let mut parsed_args = match serde_json::from_str::(&call.function.arguments) { + let args_str = call.function.arguments.as_deref().unwrap_or("{}"); + let mut parsed_args = match serde_json::from_str::(args_str) { Ok(value) => value, Err(e) => { crate::log_warn!( "Failed to parse arguments for tool '{}': {}. Args: {}", call.function.name, e, - call.function.arguments + args_str ); invalid.push(call.clone()); continue; @@ -103,15 +104,14 @@ pub fn filter_tool_calls( ); invalid.push(call.clone()); } else { - let normalized_args = serde_json::to_string(&filtered_args) - .unwrap_or_else(|_| call.function.arguments.clone()); + let normalized_args = + serde_json::to_string(&filtered_args).unwrap_or_else(|_| args_str.to_string()); valid.push(ToolCall { - index: call.index, id: call.id.clone(), - call_type: call.call_type.clone(), + tool_type: call.tool_type.clone(), function: FunctionCall { name: call.function.name.clone(), - arguments: normalized_args, + arguments: Some(normalized_args), }, }); } @@ -128,7 +128,12 @@ pub fn format_tool_calls_summary(tool_calls: &[ToolCall]) -> String { tool_calls .iter() .map(|call| { - let args = call.function.arguments.replace('\n', " "); + let args = call + .function + .arguments + .as_deref() + .unwrap_or("") + .replace('\n', " "); let truncated = if args.len() > 160 { let snippet: String = args.chars().take(160).collect(); format!("{}...", snippet) @@ -156,8 +161,8 @@ mod tests { #[test] fn test_resolve_tools_prefers_request() { - let request_tools = vec![Tool::function("test", "desc").build()]; - let mcp_tools = vec![Tool::function("mcp", "mcp desc").build()]; + let request_tools = vec![crate::tools::function_tool("test", "desc").build()]; + let mcp_tools = vec![crate::tools::function_tool("mcp", "mcp desc").build()]; let resolved = resolve_tools(Some(&request_tools), &mcp_tools); assert_eq!(resolved.len(), 1); @@ -166,7 +171,7 @@ mod tests { #[test] fn test_resolve_tools_falls_back_to_mcp() { - let mcp_tools = vec![Tool::function("mcp", "mcp desc").build()]; + let mcp_tools = vec![crate::tools::function_tool("mcp", "mcp desc").build()]; let resolved = resolve_tools(None, &mcp_tools); assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].function.name, "mcp"); @@ -174,7 +179,7 @@ mod tests { #[test] fn test_build_tool_schema_map() { - let tools = vec![Tool::function("test", "desc") + let tools = vec![crate::tools::function_tool("test", "desc") .param("arg1", "string", "desc", true) .build()]; let map = build_tool_schema_map(&tools); diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 474fbde3..c919d30d 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -10,37 +10,9 @@ pub mod schema; use serde::{Deserialize, Serialize}; use serde_json::Value; +use uuid::Uuid; -/// A tool definition following OpenAI's function calling format -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Tool { - /// Type of the tool, always "function" for now - #[serde(rename = "type")] - pub tool_type: String, - /// The function definition - pub function: FunctionDefinition, -} - -impl Tool { - /// Create a new function tool - pub fn function(name: impl Into, description: impl Into) -> ToolBuilder { - ToolBuilder::new(name.into(), description.into()) - } -} - -/// Definition of a callable function -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FunctionDefinition { - /// Name of the function - pub name: String, - /// Description of what the function does - pub description: String, - /// JSON Schema for the function parameters - pub parameters: Value, - /// Whether to enable strict schema adherence - #[serde(skip_serializing_if = "Option::is_none")] - pub strict: Option, -} +pub use openai_protocol::common::{Function, FunctionCallResponse as FunctionCall, Tool, ToolCall}; /// Builder for creating Tool definitions pub struct ToolBuilder { @@ -105,9 +77,9 @@ impl ToolBuilder { pub fn build(self) -> Tool { Tool { tool_type: "function".to_string(), - function: FunctionDefinition { + function: Function { name: self.name, - description: self.description, + description: Some(self.description), parameters: self.parameters, strict: self.strict, }, @@ -115,6 +87,11 @@ impl ToolBuilder { } } +/// Create a new function tool builder (replacement for Tool::function). +pub fn function_tool(name: impl Into, description: impl Into) -> ToolBuilder { + ToolBuilder::new(name.into(), description.into()) +} + /// Tool choice configuration #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] @@ -153,52 +130,32 @@ pub struct ToolChoiceFunction { pub name: String, } -/// A tool call made by the model -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCall { - /// Index of this tool call in the tool_calls array (streaming only) - #[serde(skip_serializing_if = "Option::is_none")] - pub index: Option, - /// Unique identifier for this tool call - pub id: String, - /// Type of tool call (always "function") - #[serde(rename = "type")] - pub call_type: String, - /// The function call details - pub function: FunctionCall, -} - -impl ToolCall { - /// Create a new tool call - pub fn new( - id: impl Into, - name: impl Into, - arguments: impl Into, - ) -> Self { - Self { - index: None, - id: id.into(), - call_type: "function".to_string(), - function: FunctionCall { - name: name.into(), - arguments: arguments.into(), - }, - } - } - - pub fn with_index(mut self, index: usize) -> Self { - self.index = Some(index); - self +/// Build a ToolCall from name/arguments with a provided ID. +pub fn new_tool_call( + id: impl Into, + name: impl Into, + arguments: impl Into, +) -> ToolCall { + ToolCall { + id: id.into(), + tool_type: "function".to_string(), + function: FunctionCall { + name: name.into(), + arguments: Some(arguments.into()), + }, } } -/// Details of a function call -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FunctionCall { - /// Name of the function to call - pub name: String, - /// JSON string of arguments - pub arguments: String, +/// Convert a parsed tool call into an OpenAI-compatible ToolCall. +pub fn tool_call_from_parser(parsed: tool_parser::ToolCall) -> ToolCall { + ToolCall { + id: format!("call_{}", Uuid::new_v4().simple()), + tool_type: "function".to_string(), + function: FunctionCall { + name: parsed.function.name, + arguments: Some(parsed.function.arguments), + }, + } } /// Result of a tool execution diff --git a/src/tools/parser.rs b/src/tools/parser.rs index 7363dc0e..2eebdbd6 100644 --- a/src/tools/parser.rs +++ b/src/tools/parser.rs @@ -3,7 +3,7 @@ //! //! Supports multiple formats used by different models. -use super::ToolCall; +use super::{new_tool_call, ToolCall}; use regex::Regex; use serde_json::Value; @@ -226,7 +226,7 @@ impl ToolParser { }; *call_id += 1; - Some(ToolCall::new( + Some(new_tool_call( format!("call_{}", call_id), name.to_string(), args_str, @@ -284,7 +284,12 @@ mod tests { let calls = parser.parse(text); assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "get_weather"); - assert!(calls[0].function.arguments.contains("Tokyo")); + assert!(calls[0] + .function + .arguments + .as_deref() + .unwrap_or("") + .contains("Tokyo")); } #[test] diff --git a/src/utils/config.rs b/src/utils/config.rs index 47bcc47d..7ac4baa2 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -224,6 +224,7 @@ pub struct EngineConfig { pub model_id: Option, pub weight_path: Option, pub weight_file: Option, + pub enforce_parser: Option, pub hf_token: Option, pub hf_token_path: Option, pub num_blocks: usize, @@ -266,6 +267,8 @@ pub struct EngineConfig { #[pyo3(get, set)] pub weight_file: Option, #[pyo3(get, set)] + pub enforce_parser: Option, + #[pyo3(get, set)] pub hf_token: Option, #[pyo3(get, set)] pub hf_token_path: Option, @@ -329,6 +332,7 @@ impl EngineConfig { weight_file: Option, hf_token: Option, hf_token_path: Option, + enforce_parser: Option, max_num_seqs: Option, config_model_len: Option, max_model_len: Option, @@ -370,6 +374,7 @@ impl EngineConfig { weight_file, hf_token, hf_token_path, + enforce_parser, num_blocks: 128, //placeholder cpu_mem_fold, kv_fraction, diff --git a/vllm_rs.pyi b/vllm_rs.pyi index 65f4251f..ff8fd9d4 100644 --- a/vllm_rs.pyi +++ b/vllm_rs.pyi @@ -52,6 +52,7 @@ class EngineConfig: weight_file: Optional[str] hf_token: Optional[str] hf_token_path: Optional[str] + enforce_parser: Optional[str] tokenizer: Optional[str] tokenizer_config: Optional[str] num_blocks: Optional[int]