diff --git a/src/utils/vllm_backend_utils.py b/src/utils/vllm_backend_utils.py
index c7596d6d..2c6cd5b9 100644
--- a/src/utils/vllm_backend_utils.py
+++ b/src/utils/vllm_backend_utils.py
@@ -27,7 +27,13 @@
 import json
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
-from typing import Optional
+from typing import Optional, Union
+
+# Optional dependency: without it, regex constraints are not syntax-checked.
+try:
+    from interegular.patterns import parse_pattern
+except ImportError:
+    parse_pattern = None
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
@@ -89,10 +95,41 @@ def from_dict(
                 str: str,
                 Optional[int]: int,
             }
-            for key, value in params_dict.items():
+
+            # Remove None values to let vLLM use defaults
+            params_dict = {k: v for k, v in params_dict.items() if v is not None}
+
+            # Iterate over a snapshot: structured_outputs may be inserted while
+            # walking the keys.
+            for key, value in list(params_dict.items()):
                 if key == "structured_outputs":
-                    params_dict[key] = StructuredOutputsParams(**json.loads(value))
-                elif key in vllm_params_dict:
+                    # Accept a JSON-encoded string or an already-decoded object,
+                    # matching the guided_decoding branch below.
+                    if isinstance(value, str):
+                        value = json.loads(value)
+                    params_dict[key] = StructuredOutputsParams(**value)
+                elif key == "guided_decoding":
+                    if isinstance(value, str):
+                        value = json.loads(value)
+                    if not isinstance(value, dict):
+                        raise ValueError("guided_decoding must be a JSON object")
+
+                    # Map guided_decoding to structured_outputs
+                    # Remove backend if present as it is not supported in StructuredOutputsParams constructor
+                    if "backend" in value:
+                        backend = value.pop("backend")
+                        if backend not in ["xgrammar", "xgrammar:no-fallback", "auto", "xgrammar:_auto"]:
+                            raise ValueError(f"guided_decoding.backend is no longer supported request-level. Provided: {backend}")
+
+                    # If structured_outputs is not already set, use guided_decoding params
+                    if "structured_outputs" not in params_dict:
+                        params_dict["structured_outputs"] = StructuredOutputsParams(**value)
+
+            # The legacy key has been mapped (or superseded) above; drop it.
+            params_dict.pop("guided_decoding", None)
+
+            for key, value in params_dict.items():
+                if key in vllm_params_dict:
                     vllm_type = vllm_params_dict[key]
                     if vllm_type in type_mapping:
                         params_dict[key] = type_mapping[vllm_type](params_dict[key])
@@ -105,6 +142,79 @@ def from_dict(
             )
         return None
 
+    def __post_init__(self):
+        """Run the parent initialisation, then validate structured_outputs."""
+        super().__post_init__()
+
+        # The validator ignores falsy values and performs the type check itself.
+        TritonSamplingParams._validate_guided_params(self.structured_outputs)
+
+    @staticmethod
+    def _validate_guided_params(params: StructuredOutputsParams):
+        """
+        Validates the structured outputs parameters.
+
+        A falsy ``params`` (for example None) is accepted without further checks.
+
+        Raises:
+            ValueError: if ``params`` is not a StructuredOutputsParams, or if its
+                regex, grammar, choice, json or whitespace_pattern field has the
+                wrong type or fails to parse.
+        """
+        if not params:
+            return
+
+        if not isinstance(params, StructuredOutputsParams):
+            raise ValueError("structured_outputs must be of type StructuredOutputsParams")
+
+        # Validate regex constraint if provided.
+        if params.regex:
+            if not isinstance(params.regex, str):
+                raise ValueError("structured_outputs.regex must be a string")
+            # parse_pattern is None when interegular is unavailable; the syntax
+            # check is skipped in that case.
+            if parse_pattern:
+                try:
+                    parse_pattern(params.regex)
+                except Exception as e:
+                    raise ValueError(f"Invalid regex constraint: {e}") from e
+
+        # backend validation is removed as it is not exposed in StructuredOutputsParams
+        # and handled during construction/mapping.
+
+        if params.grammar:
+            if not isinstance(params.grammar, str):
+                raise ValueError("grammar must be a string, describing a BNF grammar")
+
+            try:
+                # Do NOT move this import up, to avoid premature CUDA init.
+                from xgrammar import Grammar  # type: ignore[import]
+
+                # Try to parse the converted grammar, to fail this request early
+                Grammar.from_ebnf(params.grammar)
+            except ImportError:
+                # xgrammar is not importable: skip the early grammar check.
+                pass
+            except RuntimeError as e:
+                raise ValueError(f"Invalid BNF grammar: {e}") from e
+
+        # Validate choice constraint.
+        if params.choice:
+            if not isinstance(params.choice, list):
+                raise ValueError("choice must be a list")
+            if not all(isinstance(item, str) for item in params.choice):
+                raise ValueError("Each element in choice must be a string")
+
+        # Validate JSON constraint.
+        # vLLM accepts the schema either as a dict or as a JSON-encoded string.
+        if params.json:
+            if not isinstance(params.json, (dict, str)):
+                raise ValueError("json must be a JSON schema dictionary or string")
+
+        # Validate whitespace_pattern constraint.
+        if params.whitespace_pattern:
+            if not isinstance(params.whitespace_pattern, str):
+                raise ValueError("whitespace_pattern must be a string")
 
 # Copy from vllm/vllm/entrypoints/openai/api_server.py with custom stat_loggers
 @asynccontextmanager