Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 99 additions & 3 deletions src/utils/vllm_backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
import json
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Optional
from typing import Optional, Union

try:
from interegular.patterns import parse_pattern
except ImportError:
parse_pattern = None

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
Expand Down Expand Up @@ -89,10 +94,33 @@ def from_dict(
str: str,
Optional[int]: int,
}
for key, value in params_dict.items():

# Remove None values to let vLLM use defaults
params_dict = {k: v for k, v in params_dict.items() if v is not None}

for key, value in list(params_dict.items()):
if key == "structured_outputs":
params_dict[key] = StructuredOutputsParams(**json.loads(value))
elif key in vllm_params_dict:
elif key == "guided_decoding":
if isinstance(value, str):
value = json.loads(value)

# Map guided_decoding to structured_outputs
# Remove backend if present as it is not supported in StructuredOutputsParams constructor
if "backend" in value:
backend = value.pop("backend")
if backend not in ["xgrammar", "xgrammar:no-fallback", "auto", "xgrammar:_auto"]:
raise ValueError(f"guided_decoding.backend is no longer supported request-level. Provided: {backend}")

# If structured_outputs is not already set, use guided_decoding params
if "structured_outputs" not in params_dict:
params_dict["structured_outputs"] = StructuredOutputsParams(**value)

if "guided_decoding" in params_dict:
del params_dict["guided_decoding"]

for key, value in params_dict.items():
if key in vllm_params_dict:
vllm_type = vllm_params_dict[key]
if vllm_type in type_mapping:
params_dict[key] = type_mapping[vllm_type](params_dict[key])
Expand All @@ -105,6 +133,74 @@ def from_dict(
)
return None

def __post_init__(self):
super().__post_init__()

# Validate the structured outputs parameters.
if self.structured_outputs:
if not isinstance(self.structured_outputs, StructuredOutputsParams):
raise ValueError(
"structured_outputs must be of type StructuredOutputsParams"
)
TritonSamplingParams._validate_guided_params(self.structured_outputs)

@staticmethod
def _validate_guided_params(params: StructuredOutputsParams):
"""
Validates the structured outputs parameters.
Raises an exception if the parameters are invalid.
"""
if not params:
return

if not isinstance(params, StructuredOutputsParams):
raise ValueError("structured_outputs must be of type StructuredOutputsParams")

# Validate regex constraint if provided.
if params.regex:
if not isinstance(params.regex, str):
raise ValueError("structured_outputs.regex must be a string")
if parse_pattern:
try:
parse_pattern(params.regex)
except Exception as e:
raise ValueError(f"Invalid regex constraint: {e}") from e

# backend validation is removed as it is not exposed in StructuredOutputsParams
# and handled during construction/mapping.

if params.grammar:
if not isinstance(params.grammar, str):
raise ValueError("grammar must be a string, describing a BNF grammar")

try:
from xgrammar import \
Grammar # type: ignore[import]: do NOT move up to avoid premature CUDA init

# Try to parse the converted grammar, to fail this request early
Grammar.from_ebnf(params.grammar)
except ImportError:
pass
except RuntimeError as e:
raise ValueError(f"Invalid BNF grammar: {e}") from e

# Validate choice constraint.
if params.choice:
if not isinstance(params.choice, list):
raise ValueError("choice must be a list")
for item in params.choice:
if not isinstance(item, str):
raise ValueError("Each element in choice must be a string")

# Validate JSON constraint.
if params.json:
if not isinstance(params.json, dict):
raise ValueError("json must be a JSON schema dictionary")

# Validate whitespace_pattern constraint.
if params.whitespace_pattern:
if not isinstance(params.whitespace_pattern, str):
raise ValueError("whitespace_pattern must be a string")

# Copy from vllm/vllm/entrypoints/openai/api_server.py with custom stat_loggers
@asynccontextmanager
Expand Down