diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py
index 7d84be218447..57841956e455 100644
--- a/tests/entrypoints/openai/test_chat_error.py
+++ b/tests/entrypoints/openai/test_chat_error.py
@@ -153,6 +153,19 @@ async def mock_generate(*args, **kwargs):
     assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
 
 
+def test_json_schema_response_format_missing_schema():
+    """When response_format type is 'json_schema' but the json_schema field
+    is not provided, request construction should raise a validation error
+    so the API returns 400 instead of 500."""
+    with pytest.raises(Exception, match="json_schema.*must be provided"):
+        ChatCompletionRequest(
+            model=MODEL_NAME,
+            messages=[{"role": "user", "content": "Test prompt"}],
+            max_tokens=10,
+            response_format={"type": "json_schema"},
+        )
+
+
 @pytest.mark.asyncio
 async def test_chat_error_stream():
     """test finish_reason='error' returns 500 InternalServerError (streaming)"""
diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py
index 1bf0de53fa3c..16767b20bc14 100644
--- a/vllm/entrypoints/openai/chat_completion/protocol.py
+++ b/vllm/entrypoints/openai/chat_completion/protocol.py
@@ -447,7 +447,7 @@ def to_sampling_params(
                 structured_outputs_kwargs["json"] = json_schema.json_schema
             elif response_format.type == "structural_tag":
                 structural_tag = response_format
-                assert structural_tag is not None and isinstance(
+                assert isinstance(
                     structural_tag,
                     (
                         LegacyStructuralTagResponseFormat,
@@ -502,6 +502,34 @@
             skip_clone=True,  # Created fresh per request, safe to skip clone
         )
 
+    @model_validator(mode="before")
+    @classmethod
+    def validate_response_format(cls, data):
+        response_format = data.get("response_format")
+        if response_format is None:
+            return data
+
+        rf_type = (
+            response_format.get("type")
+            if isinstance(response_format, dict)
+            else getattr(response_format, "type", None)
+        )
+
+        if rf_type == "json_schema":
+            json_schema = (
+                response_format.get("json_schema")
+                if isinstance(response_format, dict)
+                else getattr(response_format, "json_schema", None)
+            )
+            if json_schema is None:
+                raise VLLMValidationError(
+                    "When response_format type is 'json_schema', the "
+                    "'json_schema' field must be provided.",
+                    parameter="response_format",
+                )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def validate_stream_options(cls, data):