Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
53223f2
[ROCm][CI] Stabilize 400 error return code for invalid schema inputs
AndreasKaratzas May 18, 2026
f5b07f2
Merge remote-tracking branch 'origin/main' into akaratza_stabilize_en…
AndreasKaratzas May 19, 2026
7128304
Merge remote-tracking branch 'origin/main' into akaratza_stabilize_en…
AndreasKaratzas May 19, 2026
01e5eac
[ROCm][CI] Stabilize 400 error return code for invalid schema inputs
AndreasKaratzas May 19, 2026
26d4c22
Merge remote-tracking branch 'origin/main' into akaratza_stabilize_en…
AndreasKaratzas May 19, 2026
41abe07
[ROCm][CI] Stabilize 400 error return code for invalid schema inputs
AndreasKaratzas May 19, 2026
de4e34b
Merge remote-tracking branch 'origin/main' into akaratza_stabilize_en…
AndreasKaratzas May 20, 2026
d596a9f
[ROCm][CI] Stabilize 400 error return code for invalid schema inputs
AndreasKaratzas May 20, 2026
1082646
Merge remote-tracking branch 'origin/main' into akaratza_stabilize_en…
AndreasKaratzas May 20, 2026
80e73ce
[ROCm][CI] Stabilize 400 error return code for invalid schema inputs
AndreasKaratzas May 20, 2026
fccb373
Merge remote-tracking branch 'origin/main' into akaratza_stabilize_en…
AndreasKaratzas May 21, 2026
d86dfd2
Merge remote-tracking branch 'origin/main' into akaratza_stabilize_en…
AndreasKaratzas May 23, 2026
0f205a9
Reverting bad merge
AndreasKaratzas May 23, 2026
9269a1e
Merge remote-tracking branch 'origin/main' into akaratza_stabilize_en…
AndreasKaratzas May 24, 2026
ef7bea3
Honor shutdown timeouts and update weight-transfer mocks
AndreasKaratzas May 24, 2026
5f13356
Removing eager check
AndreasKaratzas May 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions tests/entrypoints/openai/completion/test_shutdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
_PROCESS_EXIT_TIMEOUT = 15
_SHUTDOWN_DETECTION_TIMEOUT = 10
_CHILD_CLEANUP_TIMEOUT = 10
_INFLIGHT_REQUEST_START_TIMEOUT = 5
_INFLIGHT_REQUEST_POLL_INTERVAL = 0.1
_ABORT_CLIENT_TIMEOUT = 3


def _get_child_pids(parent_pid: int) -> list[int]:
Expand Down Expand Up @@ -71,6 +74,7 @@ class ShutdownState:
requests_after_sigterm: int = 0
aborted_requests: int = 0
connection_errors: int = 0
inflight_requests: int = 0
stop_requesting: bool = False
errors: list[str] = field(default_factory=list)

Expand All @@ -86,6 +90,7 @@ async def _concurrent_request_loop(
async def single_request():
while not state.stop_requesting:
try:
state.inflight_requests += 1
response = await client.completions.create(
model=MODEL_NAME,
prompt="Write a story: ",
Expand All @@ -110,6 +115,8 @@ async def single_request():
except Exception as e:
state.errors.append(f"Unexpected error: {e}")
break
finally:
state.inflight_requests -= 1
await asyncio.sleep(0.01)

tasks = [asyncio.create_task(single_request()) for _ in range(concurrency)]
Expand Down Expand Up @@ -392,7 +399,7 @@ async def test_abort_timeout_fails_inflight_requests():
]

with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
client = remote_server.get_async_client()
client = remote_server.get_async_client(timeout=_ABORT_CLIENT_TIMEOUT)
proc = remote_server.proc
child_pids = _get_child_pids(proc.pid)

Expand All @@ -403,7 +410,10 @@ async def test_abort_timeout_fails_inflight_requests():
_concurrent_request_loop(client, state, sigterm_sent, concurrency=10)
)

await asyncio.sleep(0.5)
deadline = time.time() + _INFLIGHT_REQUEST_START_TIMEOUT
while state.inflight_requests == 0 and time.time() < deadline:
await asyncio.sleep(_INFLIGHT_REQUEST_POLL_INTERVAL)
assert state.inflight_requests > 0

proc.send_signal(signal.SIGTERM)
sigterm_sent.set()
Expand Down
4 changes: 3 additions & 1 deletion tests/entrypoints/openai/test_openai_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def no_invalid_types(case: schemathesis.models.Case):
# the default filtered-vs-good ratio. The filter is intentional, so
# suppress the health check rather than drop the filter — dropping it
# exposes pre-existing server bugs out of scope here.
suppress_health_check=[HealthCheck.filter_too_much],
# The same nested schema can also trip Hypothesis' entropy budget while
# generating large-but-valid request bodies before vLLM is called.
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
)
def test_openapi_stateless(case: Case):
key = (
Expand Down
9 changes: 9 additions & 0 deletions tests/entrypoints/serve/disagg/test_generate_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,16 @@ class MockParallelConfig:
_api_process_rank: int = 0


# Minimal scheduler-config stand-in for tests: it carries only max_num_seqs,
# with a default of 128.
@dataclass
class MockSchedulerConfig:
max_num_seqs: int = 128


# Test stand-in bundling the model, parallel and scheduler mock configs.
# model_config and parallel_config must be supplied by the caller;
# scheduler_config falls back to a freshly constructed MockSchedulerConfig
# (via default_factory) when it is not passed.
@dataclass
class MockVllmConfig:
model_config: MockModelConfig
parallel_config: MockParallelConfig
scheduler_config: MockSchedulerConfig = field(default_factory=MockSchedulerConfig)


def _build_renderer(model_config: MockModelConfig):
Expand Down Expand Up @@ -149,6 +155,9 @@ def _mock_engine() -> MagicMock:
engine = MagicMock(spec=AsyncLLM)
engine.errored = False
engine.model_config = MockModelConfig()
engine.vllm_config = MockVllmConfig(
engine.model_config, parallel_config=MockParallelConfig()
)
engine.input_processor = MagicMock()
engine.renderer = _build_renderer(engine.model_config)
return engine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo
last_init_info: MockInitInfo | None = None
last_update_info: MockUpdateInfo | None = None

def __init__(self, config, parallel_config):
super().__init__(config, parallel_config)
def __init__(self, config, parallel_config, model):
super().__init__(config, parallel_config, model)
# Reset tracking on init
MockWeightTransferEngine.init_transfer_engine_called = False
MockWeightTransferEngine.receive_weights_called = False
Expand Down Expand Up @@ -95,9 +95,9 @@ def trainer_send_weights(self, *args, **kwargs):
pass


def mock_create_engine(config, parallel_config):
def mock_create_engine(config, parallel_config, model):
"""Mock factory function that returns our mock engine."""
return MockWeightTransferEngine(config, parallel_config)
return MockWeightTransferEngine(config, parallel_config, model)


# --- Tests ---
Expand Down
22 changes: 19 additions & 3 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,12 +1818,28 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
continue

for item in tool_calls:
if not isinstance(item, dict):
raise VLLMValidationError(
"assistant tool_calls entries must be objects.",
parameter="tool_calls",
)

function = item.get("function")
if item.get("type", "function") != "function" or not isinstance(
function, dict
):
raise VLLMValidationError(
"chat completions only support assistant tool_calls "
"of type 'function'.",
parameter="tool_calls",
)

# if arguments is None or empty string, set to {}
if content := item["function"].get("arguments"):
if content := function.get("arguments"):
if not isinstance(content, (dict, list)):
item["function"]["arguments"] = json.loads(content)
function["arguments"] = json.loads(content)
else:
item["function"]["arguments"] = {}
function["arguments"] = {}


def parse_chat_messages(
Expand Down
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def build_async_engine_client_from_engine_args(
yield async_llm
finally:
if async_llm:
async_llm.shutdown()
async_llm.shutdown(timeout=vllm_config.shutdown_timeout)


def build_app(
Expand Down
8 changes: 6 additions & 2 deletions vllm/entrypoints/openai/chat_completion/batch_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,6 @@ async def chat_completion_full_generator_batch(
``check_batch_mode`` validator, so neither needs to be handled here.
"""
created_time = int(time.time())
role = self.get_chat_request_role(request) # type: ignore[arg-type]

final_results: dict[int, RequestOutput] = {}
try:
async for prompt_idx, res in merge_async_iterators(*generators):
Expand Down Expand Up @@ -275,6 +273,12 @@ async def chat_completion_full_generator_batch(
reasoning = None
content = output.text

role = (
self.response_role
if request.add_generation_prompt
else request.messages[prompt_idx][-1]["role"]
)

message = ChatMessage(role=role, reasoning=reasoning, content=content)

if request.echo:
Expand Down
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/chat_completion/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,9 @@ class BatchChatCompletionRequest(OpenAIBaseModel):
- The ``n`` parameter must be 1 (or omitted).
"""

messages: list[list[ChatCompletionMessageParam]] = Field(..., min_length=1)
messages: list[Annotated[list[ChatCompletionMessageParam], Field(min_length=1)]] = (
Field(..., min_length=1)
)
model: str | None = None

# Shared sampling / generation fields — mirror ChatCompletionRequest.
Expand Down
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/engine/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar

import msgspec
from fastapi import Request
from openai.types.responses import ToolChoiceFunction
from pydantic import ConfigDict, TypeAdapter, ValidationError
Expand Down Expand Up @@ -426,6 +427,15 @@ async def _with_kv_transfer_rejection_cleanup(
"""Wrap a `create_*` coroutine so that, if it raises or returns an
ErrorResponse (i.e. the request never reached the engine), the KV
connector is notified to free any pinned remote-prefill blocks."""
if request.kv_transfer_params is not None:
try:
msgspec.msgpack.encode(request.kv_transfer_params)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this cause additional overhead? From my understanding, this is called even if no error occurs in the request.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed it. The cleanup wrapper now only handles the KV rejection cleanup path again. Initially, the "eager" msgpack encode was meant as a quick validation guard, but it is indeed not the right place for that check, because it also runs on successful requests.

except (OverflowError, TypeError, ValueError) as e:
close = getattr(awaitable, "close", None)
if close is not None:
close()
return self.create_error_response(e) # type: ignore[return-value]

kv_transfer_params = self.has_kv_connector and request.kv_transfer_params
if not kv_transfer_params or not kv_transfer_params.get("do_remote_prefill"):
return await awaitable
Expand Down
2 changes: 1 addition & 1 deletion vllm/entrypoints/serve/disagg/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class GenerateRequest(BaseModel):
"through out the inference process and return in response."
),
)
token_ids: list[int]
token_ids: list[int] = Field(min_length=1)
"""The token ids to generate text from."""

@field_validator("token_ids")
Expand Down
14 changes: 13 additions & 1 deletion vllm/entrypoints/serve/disagg/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import AsyncGenerator
from collections.abc import Sequence as GenericSequence

import msgspec
import numpy as np
import pybase64 as base64
from fastapi import Request
Expand Down Expand Up @@ -125,6 +126,18 @@ async def serve_tokens(
if raw_request:
raw_request.state.request_metadata = request_metadata

sampling_params = request.sampling_params
max_num_seqs = self.engine_client.vllm_config.scheduler_config.max_num_seqs
if sampling_params.n > max_num_seqs:
return self.create_error_response(
f"sampling_params.n must be at most the server's max_num_seqs "
f"({max_num_seqs}), got {sampling_params.n}."
)
try:
msgspec.msgpack.encode(sampling_params)
except (OverflowError, TypeError, ValueError) as e:
return self.create_error_response(e)

engine_input: EngineInput
if features := request.features:
# Convert PlaceholderRangeInfo → PlaceholderRange per modality.
Expand Down Expand Up @@ -164,7 +177,6 @@ async def serve_tokens(

# Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None
sampling_params = request.sampling_params

# Apply server-side ``max_tokens`` defaulting when the client did
# not set it, matching the OpenAI-compat endpoints. ``SamplingParams``
Expand Down
7 changes: 3 additions & 4 deletions vllm/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,9 @@ def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None:
timeout: Maximum time in seconds to wait for graceful shutdown
"""
if timeout is None:
timeout = 0.0

# Allow at least 5 seconds for remaining procs to terminate.
timeout = max(timeout, 5.0)
# Keep a small grace period for best-effort cleanup paths that do not
# have a user-configured shutdown timeout.
timeout = 5.0

# Shutdown the process.
for proc in procs:
Expand Down
Loading