diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py index aa7cb245..1336d67c 100644 --- a/vllm_mlx/api/models.py +++ b/vllm_mlx/api/models.py @@ -140,6 +140,34 @@ class ResponseFormat(BaseModel): json_schema: ResponseFormatJsonSchema | None = None +# ============================================================================= +# Logprobs +# ============================================================================= + + +class TopLogProb(BaseModel): + """A top log probability for a token.""" + + token: str + logprob: float + bytes: list[int] | None = None + + +class TokenLogProb(BaseModel): + """Log probability information for a single token.""" + + token: str + logprob: float + bytes: list[int] | None = None + top_logprobs: list[TopLogProb] = [] + + +class ChoiceLogProbs(BaseModel): + """Log probability information for a choice.""" + + content: list[TokenLogProb] | None = None + + # ============================================================================= # Chat Completion # ============================================================================= @@ -169,6 +197,9 @@ class ChatCompletionRequest(BaseModel): tool_choice: str | dict | None = None # "auto", "none", or specific tool # Structured output response_format: ResponseFormat | dict | None = None + # Logprobs + logprobs: bool | None = None + top_logprobs: int | None = None # 0-20, per OpenAI spec # MLLM-specific parameters video_fps: float | None = None video_max_frames: int | None = None @@ -199,6 +230,7 @@ class ChatCompletionChoice(BaseModel): index: int = 0 message: AssistantMessage finish_reason: str | None = "stop" + logprobs: ChoiceLogProbs | None = None class Usage(BaseModel): @@ -235,6 +267,9 @@ class CompletionRequest(BaseModel): max_tokens: int | None = None stream: bool = False stop: list[str] | None = None + # Logprobs + logprobs: bool | None = None + top_logprobs: int | None = None # 0-20, per OpenAI spec # Request timeout in seconds (None = use server default) timeout: float | 
None = None @@ -245,6 +280,7 @@ class CompletionChoice(BaseModel): index: int = 0 text: str finish_reason: str | None = "stop" + logprobs: ChoiceLogProbs | None = None class CompletionResponse(BaseModel): @@ -438,6 +474,7 @@ class ChatCompletionChunkChoice(BaseModel): index: int = 0 delta: ChatCompletionChunkDelta finish_reason: str | None = None + logprobs: ChoiceLogProbs | None = None class ChatCompletionChunk(BaseModel): diff --git a/vllm_mlx/api/tool_logits.py b/vllm_mlx/api/tool_logits.py index 854af08d..575f98b3 100644 --- a/vllm_mlx/api/tool_logits.py +++ b/vllm_mlx/api/tool_logits.py @@ -17,12 +17,35 @@ from __future__ import annotations +import json import logging +import re from typing import Any, Protocol logger = logging.getLogger(__name__) +def _extract_param_schemas(tools: list[dict] | None) -> dict[str, dict]: + """ + Extract parameter JSON schemas from tool definitions. + + Returns a dict mapping "tool_name.param_name" -> JSON schema for that parameter. + """ + if not tools: + return {} + + schemas: dict[str, dict] = {} + for tool in tools: + func = tool.get("function", tool) + tool_name = func.get("name", "") + params = func.get("parameters", {}) + properties = params.get("properties", {}) + for param_name, param_schema in properties.items(): + key = f"{tool_name}.{param_name}" + schemas[key] = param_schema + return schemas + + class ToolLogitsProcessor(Protocol): """Protocol for tool call logits processors.""" @@ -73,16 +96,23 @@ class MiniMaxToolLogitsProcessor: ("", ""), ] - def __init__(self, tokenizer: Any, bias_strength: float = 20.0): + def __init__( + self, + tokenizer: Any, + bias_strength: float = 20.0, + tool_schemas: dict[str, dict] | None = None, + ): """ Initialize the MiniMax tool logits processor. Args: tokenizer: The tokenizer to use for encoding patterns. bias_strength: Logits bias to add to expected tokens. + tool_schemas: Map of "tool.param" -> JSON schema for parameter value constraint. 
""" self.tokenizer = tokenizer self.bias_strength = bias_strength + self._tool_schemas = tool_schemas or {} # Pre-tokenize structural fragments self._pattern_tokens: dict[str, list[int]] = {} @@ -91,6 +121,13 @@ def __init__(self, tokenizer: Any, bias_strength: float = 20.0): if tokens: self._pattern_tokens[pattern] = tokens + # Pre-tokenize common JSON structural tokens for parameter value bias + self._json_tokens: dict[str, list[int]] = {} + for char in ['"', '{', '[', ']', '}', ',', ':', 'true', 'false', 'null']: + toks = tokenizer.encode(char, add_special_tokens=False) + if toks: + self._json_tokens[char] = toks + # State tracking self._recent_text = "" self._active_pattern: str | None = None @@ -99,6 +136,12 @@ def __init__(self, tokenizer: Any, bias_strength: float = 20.0): self._consecutive_bias_count = 0 # Safety: escape hatch for stuck patterns self._max_consecutive_bias = 50 # Max tokens to bias before force-resetting + # Parameter value tracking for structural constraint + self._current_tool_name: str | None = None + self._current_param_name: str | None = None + self._in_parameter_value = False + self._param_value_text = "" # Accumulated text of current param value + def reset(self) -> None: """Reset state for a new generation.""" self._recent_text = "" @@ -106,6 +149,108 @@ def reset(self) -> None: self._pattern_pos = 0 self._last_param_close_pos = -1 self._consecutive_bias_count = 0 + self._current_tool_name = None + self._current_param_name = None + self._in_parameter_value = False + self._param_value_text = "" + + # Regex patterns for detecting tool/parameter context + _INVOKE_RE = re.compile(r'') + _PARAM_CLOSE_RE = re.compile(r'') + + def _update_param_state(self) -> None: + """Update parameter value tracking state from recent text.""" + text = self._recent_text + + # Detect + for m in self._INVOKE_RE.finditer(text): + self._current_tool_name = m.group(1) + + # Detect → entering value + for m in self._PARAM_OPEN_RE.finditer(text): + 
self._current_param_name = m.group(1) + end_pos = m.end() + # Only activate if this is the latest unclosed parameter + close_after = text.find("</parameter>", end_pos) + if close_after == -1: + # No close tag after this open → we're inside value + self._in_parameter_value = True + self._param_value_text = text[end_pos:] + + # Detect </parameter> → leaving value + if self._in_parameter_value: + if "</parameter>" in self._param_value_text or text.rstrip().endswith( + "</parameter>" + ): + self._in_parameter_value = False + self._param_value_text = "" + + def _apply_param_value_bias(self, logits: Any) -> Any | None: + """ + Apply JSON structural bias when generating a parameter value. + + Uses the schema type to bias toward valid JSON tokens: + - string: bias toward quote characters + - number/integer: bias toward digit tokens + - boolean: bias toward 'true'/'false' + - object/array: bias toward opening braces/brackets + + Returns biased logits, or None to skip bias (let model generate freely). + """ + import mlx.core as mx + + if not self._current_tool_name or not self._current_param_name: + return None + + schema_key = f"{self._current_tool_name}.{self._current_param_name}" + schema = self._tool_schemas.get(schema_key) + if not schema: + return None + + param_type = schema.get("type", "") + value_text = self._param_value_text.strip() + + # Only bias at the START of a value (first meaningful token) + # Once the model has started generating, let it continue freely + if len(value_text) > 2: + return None + + bias_tokens: list[int] = [] + weak_bias = self.bias_strength * 0.3 # Lighter bias for value guidance + + if param_type == "string": + # Strings should start with " + if not value_text: + bias_tokens = self._json_tokens.get('"', []) + elif param_type in ("number", "integer"): + # Numbers: bias toward digit tokens (0-9, -, .) 
+ for ch in "0123456789-.": + toks = self.tokenizer.encode(ch, add_special_tokens=False) + if toks: + bias_tokens.extend(toks) + elif param_type == "boolean": + # Bias toward 'true' and 'false' + for val in ["true", "false"]: + toks = self._json_tokens.get(val, []) + bias_tokens.extend(toks) + elif param_type == "object": + if not value_text: + bias_tokens = self._json_tokens.get("{", []) + elif param_type == "array": + if not value_text: + bias_tokens = self._json_tokens.get("[", []) + + if not bias_tokens: + return None + + bias = mx.zeros_like(logits) + for tok in bias_tokens: + if logits.ndim == 2: + bias[0, tok] = weak_bias + else: + bias[tok] = weak_bias + return logits + bias def __call__(self, token_ids: Any, logits: Any) -> Any: """ @@ -149,6 +294,15 @@ def __call__(self, token_ids: Any, logits: Any) -> Any: if len(self._recent_text) > 200: self._recent_text = self._recent_text[-200:] + # --- Parameter value state tracking --- + self._update_param_state() + + # If inside a parameter value, apply JSON structural bias + if self._in_parameter_value and self._tool_schemas: + biased = self._apply_param_value_bias(logits) + if biased is not None: + return biased + # If we're tracking an active pattern, bias toward next token if self._active_pattern is not None: pattern_tokens = self._pattern_tokens.get(self._active_pattern, []) @@ -219,7 +373,10 @@ def __call__(self, token_ids: Any, logits: Any) -> Any: def create_tool_logits_processor( - parser_name: str, tokenizer: Any, bias_strength: float = 20.0 + parser_name: str, + tokenizer: Any, + bias_strength: float = 20.0, + tools: list[dict] | None = None, ) -> ToolLogitsProcessor | None: """ Factory function to create a tool logits processor for a given parser. @@ -228,11 +385,62 @@ def create_tool_logits_processor( parser_name: Name of the tool call parser (e.g., "minimax"). tokenizer: The tokenizer instance. bias_strength: Logits bias strength. 
+ tools: Optional tool definitions for parameter value schema constraint. Returns: A logits processor instance, or None if not supported for this parser. """ + tool_schemas = _extract_param_schemas(tools) if parser_name == "minimax": - return MiniMaxToolLogitsProcessor(tokenizer, bias_strength=bias_strength) + return MiniMaxToolLogitsProcessor( + tokenizer, + bias_strength=bias_strength, + tool_schemas=tool_schemas, + ) # Future: add support for other parsers (hermes, llama, etc.) return None + + +def validate_param_value(value: str, schema: dict) -> tuple[bool, str | None]: + """ + Validate a parameter value against its JSON schema (lightweight). + + Used by SimpleEngine for post-generation validation of tool call parameters. + + Args: + value: The parameter value string. + schema: JSON schema for the parameter. + + Returns: + (is_valid, error_message) tuple. + """ + param_type = schema.get("type", "") + + # Try to parse as JSON first + try: + parsed = json.loads(value) + except (json.JSONDecodeError, ValueError): + # Not valid JSON — check if it's a bare string (common for string params) + if param_type == "string": + return True, None # Bare strings are acceptable for string params + return False, f"Invalid JSON value: {value!r}" + + # Type check + if param_type == "string" and not isinstance(parsed, str): + return False, f"Expected string, got {type(parsed).__name__}" + elif param_type == "integer" and not isinstance(parsed, int): + return False, f"Expected integer, got {type(parsed).__name__}" + elif param_type == "number" and not isinstance(parsed, (int, float)): + return False, f"Expected number, got {type(parsed).__name__}" + elif param_type == "boolean" and not isinstance(parsed, bool): + return False, f"Expected boolean, got {type(parsed).__name__}" + elif param_type == "array" and not isinstance(parsed, list): + return False, f"Expected array, got {type(parsed).__name__}" + elif param_type == "object" and not isinstance(parsed, dict): + return False, 
f"Expected object, got {type(parsed).__name__}" + + # Enum check + if "enum" in schema and parsed not in schema["enum"]: + return False, f"Value {parsed!r} not in enum {schema['enum']}" + + return True, None diff --git a/vllm_mlx/engine/base.py b/vllm_mlx/engine/base.py index fc5e7045..877e7f41 100644 --- a/vllm_mlx/engine/base.py +++ b/vllm_mlx/engine/base.py @@ -25,6 +25,8 @@ class GenerationOutput: # For streaming new_text: str = "" finished: bool = True + # Per-token logprobs (mx.array of shape [vocab_size] for current token) + logprobs: Any = None class BaseEngine(ABC): diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index df42b3e8..3d4514cc 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -235,13 +235,18 @@ async def stream_generate( if finished: finish_reason = getattr(chunk, "finish_reason", "stop") + # Pass current token ID for logprobs extraction + current_token = getattr(chunk, "token", 0) + yield GenerationOutput( text=accumulated_text, new_text=new_text, + tokens=[current_token] if current_token else [], prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, finished=finished, finish_reason=finish_reason, + logprobs=getattr(chunk, "logprobs", None), ) if finished: diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index c95b8e98..84738db0 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -10,7 +10,7 @@ import logging import time from dataclasses import dataclass -from typing import Iterator +from typing import Any, Iterator import mlx.core as mx @@ -34,6 +34,7 @@ class StreamingOutput: token: int finished: bool = False finish_reason: str | None = None + logprobs: Any = None # mx.array of shape [vocab_size] from mlx-lm class MLXLanguageModel: @@ -401,6 +402,7 @@ def stream_generate( token=response.token if hasattr(response, "token") else 0, finished=finished, finish_reason=finish_reason, + logprobs=getattr(response, "logprobs", None), ) if finished: diff --git 
a/vllm_mlx/server.py b/vllm_mlx/server.py index 4c5265a5..b3b5b039 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -69,6 +69,7 @@ ChatCompletionChunkDelta, # noqa: F401 ChatCompletionRequest, ChatCompletionResponse, + ChoiceLogProbs, CompletionChoice, # noqa: F401 CompletionRequest, CompletionResponse, @@ -88,7 +89,9 @@ Message, # noqa: F401 ModelInfo, # noqa: F401 ModelsResponse, + TokenLogProb, ToolCall, + TopLogProb, Usage, # noqa: F401 VideoUrl, # noqa: F401 ) @@ -496,6 +499,51 @@ def _parse_tool_calls_with_parser( return parse_tool_calls(output_text, request_dict) +def _validate_tool_call_params( + tool_calls: list, tools: list +) -> None: + """ + Validate tool call parameter values against their schemas (post-generation). + + Logs warnings for invalid parameters but does not block the response. + This provides graceful degradation for SimpleEngine mode where logits-level + constraint is not available. + """ + from .api.tool_logits import _extract_param_schemas, validate_param_value + + tool_defs = [t.model_dump() if hasattr(t, "model_dump") else t for t in tools] + schemas = _extract_param_schemas(tool_defs) + + for tc in tool_calls: + func = tc.function if hasattr(tc, "function") else tc.get("function", {}) + func_name = func.name if hasattr(func, "name") else func.get("name", "") + args_str = func.arguments if hasattr(func, "arguments") else func.get("arguments", "{}") + + try: + args = json.loads(args_str) + except (json.JSONDecodeError, ValueError): + logger.warning( + f"Tool call '{func_name}': arguments is not valid JSON: {args_str!r}" + ) + continue + + if not isinstance(args, dict): + continue + + for param_name, param_value in args.items(): + schema_key = f"{func_name}.{param_name}" + schema = schemas.get(schema_key) + if not schema: + continue + is_valid, error = validate_param_value( + json.dumps(param_value), schema + ) + if not is_valid: + logger.warning( + f"Tool call '{func_name}' param '{param_name}': {error}" + ) + + def 
_detect_native_tool_support() -> bool: """ Detect if the active tool parser supports native tool format. @@ -655,9 +703,12 @@ def load_model( tokenizer = _engine.tokenizer if tokenizer is not None: # Create factory that produces fresh processors per request + # Accepts optional tools for parameter value schema constraint def _make_factory(parser_name, tok): - def factory(): - return create_tool_logits_processor(parser_name, tok) + def factory(tools=None): + return create_tool_logits_processor( + parser_name, tok, tools=tools + ) return factory factory = _make_factory(_tool_call_parser, tokenizer) @@ -677,6 +728,59 @@ def factory(): logger.info(f"Default max tokens: {_default_max_tokens}") +def _extract_token_logprob( + logprobs_array, token_id: int, tokenizer, top_k: int +) -> TokenLogProb: + """ + Convert an mx.array of log-probabilities to a TokenLogProb with top-k alternatives. + + Args: + logprobs_array: mx.array of shape [vocab_size] with log-probabilities. + token_id: The actually sampled token ID. + tokenizer: Tokenizer for decoding token IDs to text. + top_k: Number of top alternatives to include. + + Returns: + TokenLogProb with the sampled token and top-k alternatives. 
+ """ + import mlx.core as mx + import numpy as np + + # Convert to float32 first — mx.array may be bfloat16 which numpy can't handle + if hasattr(logprobs_array, "astype"): + logprobs_array = logprobs_array.astype(mx.float32) + probs = np.array(logprobs_array).flatten() + # Get top-k indices + top_k = min(top_k, len(probs)) + top_indices = np.argpartition(probs, -top_k)[-top_k:] + top_indices = top_indices[np.argsort(probs[top_indices])][::-1] + + # Build top_logprobs list + top_logprobs = [] + for idx in top_indices: + idx = int(idx) + tok_text = tokenizer.decode([idx]) + tok_bytes = list(tok_text.encode("utf-8", errors="replace")) + top_logprobs.append( + TopLogProb( + token=tok_text, + logprob=float(probs[idx]), + bytes=tok_bytes, + ) + ) + + # The sampled token + sampled_text = tokenizer.decode([token_id]) + sampled_bytes = list(sampled_text.encode("utf-8", errors="replace")) + + return TokenLogProb( + token=sampled_text, + logprob=float(probs[token_id]) if token_id < len(probs) else 0.0, + bytes=sampled_bytes, + top_logprobs=top_logprobs, + ) + + def get_usage(output: GenerationOutput) -> Usage: """Extract usage metrics from GenerationOutput.""" total_prompt_tokens = ( @@ -1467,6 +1571,15 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re """ engine = get_engine() + # Validate top_logprobs range (OpenAI spec: 0-20) + if request.top_logprobs is not None and ( + request.top_logprobs < 0 or request.top_logprobs > 20 + ): + raise HTTPException( + status_code=400, + detail="top_logprobs must be between 0 and 20", + ) + # --- Detailed request logging --- n_msgs = len(request.messages) msg_roles = [m.role for m in request.messages] @@ -1564,6 +1677,11 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re if _gc_control and gc_was_enabled: gc.disable() + # Determine if we need per-token logprobs + want_logprobs = request.logprobs and request.top_logprobs + top_k_logprobs = request.top_logprobs or 0 + 
token_logprobs_list: list[TokenLogProb] = [] + # Check if we should use guided generation for JSON schema use_guided = False json_schema = None @@ -1575,7 +1693,21 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re logger.info("Using guided generation for JSON schema enforcement") try: - if use_guided and json_schema: + if want_logprobs and not use_guided: + # Use streaming internally to collect per-token logprobs + output = None + async for chunk in engine.stream_chat(messages=messages, **chat_kwargs): + output = chunk + if chunk.logprobs is not None and chunk.new_text: + token_id = chunk.tokens[-1] if chunk.tokens else 0 + token_logprobs_list.append( + _extract_token_logprob( + chunk.logprobs, token_id, engine.tokenizer, top_k_logprobs + ) + ) + if output is None: + return Response(status_code=499) + elif use_guided and json_schema: # Use guided generation for constrained JSON output # Fall back to standard generation if guided fails (bad schema, etc.) try: @@ -1622,6 +1754,10 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re # Parse tool calls from output using configured parser cleaned_text, tool_calls = _parse_tool_calls_with_parser(output.text, request) + # Validate tool call parameter values against schemas (SimpleEngine post-generation check) + if tool_calls and request.tools: + _validate_tool_call_params(tool_calls, request.tools) + # Extract reasoning content FIRST (strips channel tokens before JSON extraction) reasoning_text = None if _reasoning_parser and not tool_calls: @@ -1657,6 +1793,11 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re # This handles Qwen3 reasoning mode: "Let me think... 
{json}" final_content = extract_json_from_response(final_content) + # Build logprobs for response if requested + choice_logprobs = None + if want_logprobs and token_logprobs_list: + choice_logprobs = ChoiceLogProbs(content=token_logprobs_list) + return ChatCompletionResponse( model=request.model, choices=[ @@ -1667,6 +1808,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re tool_calls=tool_calls, ), finish_reason=finish_reason, + logprobs=choice_logprobs, ) ], usage=Usage( @@ -2112,6 +2254,20 @@ async def stream_chat_completion( # Check if we should include usage in the final chunk include_usage = request.stream_options and request.stream_options.include_usage + # Logprobs configuration + want_logprobs = request.logprobs and request.top_logprobs + top_k_logprobs = request.top_logprobs or 0 + + def _build_chunk_logprobs(output: GenerationOutput) -> ChoiceLogProbs | None: + """Build ChoiceLogProbs for a streaming chunk if logprobs requested.""" + if not want_logprobs or output.logprobs is None: + return None + token_id = output.tokens[-1] if output.tokens else 0 + token_lp = _extract_token_logprob( + output.logprobs, token_id, engine.tokenizer, top_k_logprobs + ) + return ChoiceLogProbs(content=[token_lp]) + # First chunk with role first_chunk = ChatCompletionChunk( id=response_id, @@ -2279,6 +2435,7 @@ async def stream_chat_completion( if (output.finished and tool_calls_detected) else (output.finish_reason if output.finished else None) ), + logprobs=_build_chunk_logprobs(output), ) ], usage=get_usage(output) if output.finished else None, @@ -2355,6 +2512,7 @@ async def stream_chat_completion( if (output.finished and tool_calls_detected) else (output.finish_reason if output.finished else None) ), + logprobs=_build_chunk_logprobs(output), ) ], usage=get_usage(output) if output.finished else None,