Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions vllm_mlx/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,34 @@ class ResponseFormat(BaseModel):
json_schema: ResponseFormatJsonSchema | None = None


# =============================================================================
# Logprobs
# =============================================================================


class TopLogProb(BaseModel):
    """A top log probability for a token.

    One alternative-token entry in an OpenAI-style ``top_logprobs`` list.
    """

    # Token text as produced by the tokenizer.
    token: str
    # Log probability assigned to this token.
    logprob: float
    # Byte values of the token (UTF-8 per the OpenAI spec); None when unavailable.
    bytes: list[int] | None = None


class TokenLogProb(BaseModel):
    """Log probability information for a single sampled token."""

    # The sampled token text.
    token: str
    # Log probability of the sampled token.
    logprob: float
    # Byte values of the token (UTF-8 per the OpenAI spec); None when unavailable.
    bytes: list[int] | None = None
    # Most-likely alternative tokens at this position.
    # NOTE: the mutable `[]` default is safe here because pydantic copies
    # field defaults per instance (unlike a plain class attribute).
    top_logprobs: list[TopLogProb] = []


class ChoiceLogProbs(BaseModel):
    """Log probability information for a choice.

    Wrapper matching the OpenAI response shape: ``content`` holds one
    entry per generated token, or None when logprobs were not populated.
    """

    content: list[TokenLogProb] | None = None


# =============================================================================
# Chat Completion
# =============================================================================
Expand Down Expand Up @@ -169,6 +197,9 @@ class ChatCompletionRequest(BaseModel):
tool_choice: str | dict | None = None # "auto", "none", or specific tool
# Structured output
response_format: ResponseFormat | dict | None = None
# Logprobs
logprobs: bool | None = None
top_logprobs: int | None = None # 0-20, per OpenAI spec
# MLLM-specific parameters
video_fps: float | None = None
video_max_frames: int | None = None
Expand Down Expand Up @@ -199,6 +230,7 @@ class ChatCompletionChoice(BaseModel):
index: int = 0
message: AssistantMessage
finish_reason: str | None = "stop"
logprobs: ChoiceLogProbs | None = None


class Usage(BaseModel):
Expand Down Expand Up @@ -235,6 +267,9 @@ class CompletionRequest(BaseModel):
max_tokens: int | None = None
stream: bool = False
stop: list[str] | None = None
# Logprobs
logprobs: bool | None = None
top_logprobs: int | None = None # 0-20, per OpenAI spec
# Request timeout in seconds (None = use server default)
timeout: float | None = None

Expand All @@ -245,6 +280,7 @@ class CompletionChoice(BaseModel):
index: int = 0
text: str
finish_reason: str | None = "stop"
logprobs: ChoiceLogProbs | None = None


class CompletionResponse(BaseModel):
Expand Down Expand Up @@ -438,6 +474,7 @@ class ChatCompletionChunkChoice(BaseModel):
index: int = 0
delta: ChatCompletionChunkDelta
finish_reason: str | None = None
logprobs: ChoiceLogProbs | None = None


class ChatCompletionChunk(BaseModel):
Expand Down
214 changes: 211 additions & 3 deletions vllm_mlx/api/tool_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,35 @@

from __future__ import annotations

import json
import logging
import re
from typing import Any, Protocol

logger = logging.getLogger(__name__)


def _extract_param_schemas(tools: list[dict] | None) -> dict[str, dict]:
"""
Extract parameter JSON schemas from tool definitions.

Returns a dict mapping "tool_name.param_name" -> JSON schema for that parameter.
"""
if not tools:
return {}

schemas: dict[str, dict] = {}
for tool in tools:
func = tool.get("function", tool)
tool_name = func.get("name", "")
params = func.get("parameters", {})
properties = params.get("properties", {})
for param_name, param_schema in properties.items():
key = f"{tool_name}.{param_name}"
schemas[key] = param_schema
return schemas


class ToolLogitsProcessor(Protocol):
"""Protocol for tool call logits processors."""

Expand Down Expand Up @@ -73,16 +96,23 @@ class MiniMaxToolLogitsProcessor:
("</minimax:tool_call>", "</invoke>"),
]

def __init__(self, tokenizer: Any, bias_strength: float = 20.0):
def __init__(
self,
tokenizer: Any,
bias_strength: float = 20.0,
tool_schemas: dict[str, dict] | None = None,
):
"""
Initialize the MiniMax tool logits processor.

Args:
tokenizer: The tokenizer to use for encoding patterns.
bias_strength: Logits bias to add to expected tokens.
tool_schemas: Map of "tool.param" -> JSON schema for parameter value constraint.
"""
self.tokenizer = tokenizer
self.bias_strength = bias_strength
self._tool_schemas = tool_schemas or {}

# Pre-tokenize structural fragments
self._pattern_tokens: dict[str, list[int]] = {}
Expand All @@ -91,6 +121,13 @@ def __init__(self, tokenizer: Any, bias_strength: float = 20.0):
if tokens:
self._pattern_tokens[pattern] = tokens

# Pre-tokenize common JSON structural tokens for parameter value bias
self._json_tokens: dict[str, list[int]] = {}
for char in ['"', '{', '[', ']', '}', ',', ':', 'true', 'false', 'null']:
toks = tokenizer.encode(char, add_special_tokens=False)
if toks:
self._json_tokens[char] = toks

# State tracking
self._recent_text = ""
self._active_pattern: str | None = None
Expand All @@ -99,13 +136,121 @@ def __init__(self, tokenizer: Any, bias_strength: float = 20.0):
self._consecutive_bias_count = 0 # Safety: escape hatch for stuck patterns
self._max_consecutive_bias = 50 # Max tokens to bias before force-resetting

# Parameter value tracking for structural constraint
self._current_tool_name: str | None = None
self._current_param_name: str | None = None
self._in_parameter_value = False
self._param_value_text = "" # Accumulated text of current param value

def reset(self) -> None:
"""Reset state for a new generation."""
self._recent_text = ""
self._active_pattern = None
self._pattern_pos = 0
self._last_param_close_pos = -1
self._consecutive_bias_count = 0
self._current_tool_name = None
self._current_param_name = None
self._in_parameter_value = False
self._param_value_text = ""

# Regex patterns for detecting tool/parameter context
_INVOKE_RE = re.compile(r'<invoke\s+name="([^"]+)"')
_PARAM_OPEN_RE = re.compile(r'<parameter\s+name="([^"]+)">')
_PARAM_CLOSE_RE = re.compile(r'</parameter>')

def _update_param_state(self) -> None:
"""Update parameter value tracking state from recent text."""
text = self._recent_text

# Detect <invoke name="tool_name">
for m in self._INVOKE_RE.finditer(text):
self._current_tool_name = m.group(1)

# Detect <parameter name="param_name"> → entering value
for m in self._PARAM_OPEN_RE.finditer(text):
self._current_param_name = m.group(1)
end_pos = m.end()
# Only activate if this is the latest unclosed parameter
close_after = text.find("</parameter>", end_pos)
if close_after == -1:
# No close tag after this open → we're inside value
self._in_parameter_value = True
self._param_value_text = text[end_pos:]

# Detect </parameter> → leaving value
if self._in_parameter_value:
if "</parameter>" in self._param_value_text or text.rstrip().endswith(
"</parameter>"
):
self._in_parameter_value = False
self._param_value_text = ""

def _apply_param_value_bias(self, logits: Any) -> Any | None:
"""
Apply JSON structural bias when generating a parameter value.

Uses the schema type to bias toward valid JSON tokens:
- string: bias toward quote characters
- number/integer: bias toward digit tokens
- boolean: bias toward 'true'/'false'
- object/array: bias toward opening braces/brackets

Returns biased logits, or None to skip bias (let model generate freely).
"""
import mlx.core as mx

if not self._current_tool_name or not self._current_param_name:
return None

schema_key = f"{self._current_tool_name}.{self._current_param_name}"
schema = self._tool_schemas.get(schema_key)
if not schema:
return None

param_type = schema.get("type", "")
value_text = self._param_value_text.strip()

# Only bias at the START of a value (first meaningful token)
# Once the model has started generating, let it continue freely
if len(value_text) > 2:
return None

bias_tokens: list[int] = []
weak_bias = self.bias_strength * 0.3 # Lighter bias for value guidance

if param_type == "string":
# Strings should start with "
if not value_text:
bias_tokens = self._json_tokens.get('"', [])
elif param_type in ("number", "integer"):
# Numbers: bias toward digit tokens (0-9, -, .)
for ch in "0123456789-.":
toks = self.tokenizer.encode(ch, add_special_tokens=False)
if toks:
bias_tokens.extend(toks)
elif param_type == "boolean":
# Bias toward 'true' and 'false'
for val in ["true", "false"]:
toks = self._json_tokens.get(val, [])
bias_tokens.extend(toks)
elif param_type == "object":
if not value_text:
bias_tokens = self._json_tokens.get("{", [])
elif param_type == "array":
if not value_text:
bias_tokens = self._json_tokens.get("[", [])

if not bias_tokens:
return None

bias = mx.zeros_like(logits)
for tok in bias_tokens:
if logits.ndim == 2:
bias[0, tok] = weak_bias
else:
bias[tok] = weak_bias
return logits + bias

def __call__(self, token_ids: Any, logits: Any) -> Any:
"""
Expand Down Expand Up @@ -149,6 +294,15 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:
if len(self._recent_text) > 200:
self._recent_text = self._recent_text[-200:]

# --- Parameter value state tracking ---
self._update_param_state()

# If inside a parameter value, apply JSON structural bias
if self._in_parameter_value and self._tool_schemas:
biased = self._apply_param_value_bias(logits)
if biased is not None:
return biased

# If we're tracking an active pattern, bias toward next token
if self._active_pattern is not None:
pattern_tokens = self._pattern_tokens.get(self._active_pattern, [])
Expand Down Expand Up @@ -219,7 +373,10 @@ def __call__(self, token_ids: Any, logits: Any) -> Any:


def create_tool_logits_processor(
    parser_name: str,
    tokenizer: Any,
    bias_strength: float = 20.0,
    tools: list[dict] | None = None,
) -> ToolLogitsProcessor | None:
    """
    Factory function to create a tool logits processor for a given parser.

    Args:
        parser_name: Name of the tool call parser (e.g., "minimax").
        tokenizer: The tokenizer instance.
        bias_strength: Logits bias strength.
        tools: Optional tool definitions for parameter value schema constraint.

    Returns:
        A logits processor instance, or None if not supported for this parser.
    """
    tool_schemas = _extract_param_schemas(tools)
    if parser_name == "minimax":
        return MiniMaxToolLogitsProcessor(
            tokenizer,
            bias_strength=bias_strength,
            tool_schemas=tool_schemas,
        )
    # Future: add support for other parsers (hermes, llama, etc.)
    return None


def validate_param_value(value: str, schema: dict) -> tuple[bool, str | None]:
"""
Validate a parameter value against its JSON schema (lightweight).

Used by SimpleEngine for post-generation validation of tool call parameters.

Args:
value: The parameter value string.
schema: JSON schema for the parameter.

Returns:
(is_valid, error_message) tuple.
"""
param_type = schema.get("type", "")

# Try to parse as JSON first
try:
parsed = json.loads(value)
except (json.JSONDecodeError, ValueError):
# Not valid JSON — check if it's a bare string (common for string params)
if param_type == "string":
return True, None # Bare strings are acceptable for string params
return False, f"Invalid JSON value: {value!r}"

# Type check
if param_type == "string" and not isinstance(parsed, str):
return False, f"Expected string, got {type(parsed).__name__}"
elif param_type == "integer" and not isinstance(parsed, int):
return False, f"Expected integer, got {type(parsed).__name__}"
elif param_type == "number" and not isinstance(parsed, (int, float)):
return False, f"Expected number, got {type(parsed).__name__}"
elif param_type == "boolean" and not isinstance(parsed, bool):
return False, f"Expected boolean, got {type(parsed).__name__}"
elif param_type == "array" and not isinstance(parsed, list):
return False, f"Expected array, got {type(parsed).__name__}"
elif param_type == "object" and not isinstance(parsed, dict):
return False, f"Expected object, got {type(parsed).__name__}"

# Enum check
if "enum" in schema and parsed not in schema["enum"]:
return False, f"Value {parsed!r} not in enum {schema['enum']}"

return True, None
2 changes: 2 additions & 0 deletions vllm_mlx/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class GenerationOutput:
# For streaming
new_text: str = ""
finished: bool = True
# Per-token logprobs (mx.array of shape [vocab_size] for current token)
logprobs: Any = None


class BaseEngine(ABC):
Expand Down
5 changes: 5 additions & 0 deletions vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,18 @@ async def stream_generate(
if finished:
finish_reason = getattr(chunk, "finish_reason", "stop")

# Pass current token ID for logprobs extraction
current_token = getattr(chunk, "token", 0)

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
tokens=[current_token] if current_token else [],
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
finished=finished,
finish_reason=finish_reason,
logprobs=getattr(chunk, "logprobs", None),
)

if finished:
Expand Down
Loading