Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_chat_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ async def _fake_process_inputs(
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}

Expand Down
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_completion_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async def _fake_process_inputs(
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}

Expand Down
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ async def _fake_process_inputs(
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}

Expand Down
61 changes: 61 additions & 0 deletions tests/v1/engine/test_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
from vllm.assets.image import ImageAsset
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
Expand Down Expand Up @@ -484,6 +491,60 @@ async def test_dp_rank_argument():
pass


@pytest.mark.asyncio(scope="module")
async def test_header_dp_rank_argument():
    """Verify that the X-data-parallel-rank request header is honored.

    A rank within range (0) must produce a normal chat completion, while an
    out-of-range rank (1, with a single-rank engine) must be reported as an
    ErrorResponse rather than raising.
    """
    with ExitStack() as cleanup:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(engine.shutdown)

        model_name = "test-model"

        # Register the served model before building the chat handler.
        served_models = OpenAIServingModels(
            engine_client=engine,
            base_model_paths=[BaseModelPath(name=model_name, model_path=model_name)],
        )

        chat_handler = OpenAIServingChat(
            engine_client=engine,
            models=served_models,
            response_role="assistant",
            chat_template=None,
            chat_template_content_format="auto",
            request_logger=None,
        )

        # A minimal, deterministic chat request reused for both header values.
        request = ChatCompletionRequest(
            model=model_name,
            messages=[{"role": "user", "content": TEXT_PROMPT}],
            max_tokens=100,
            temperature=1.0,
            seed=33,
        )

        raw_request = MagicMock()
        raw_request.state = MagicMock()

        # In-range DP rank: expect a successful completion.
        raw_request.headers = {"X-data-parallel-rank": "0"}
        ok_response = await chat_handler.create_chat_completion(request, raw_request)
        assert isinstance(ok_response, ChatCompletionResponse), (
            "Expected a ChatCompletionResponse for valid DP rank"
        )

        # Out-of-range DP rank: expect a graceful error, not an exception.
        raw_request.headers = {"X-data-parallel-rank": "1"}
        bad_response = await chat_handler.create_chat_completion(request, raw_request)
        assert isinstance(bad_response, ErrorResponse), (
            "Expected an ErrorResponse for out-of-range DP rank"
        )


@pytest.mark.asyncio
async def test_check_health():
"""Test that check_health returns normally for healthy engine
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ async def create_chat_completion(
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)

generator = self.engine_client.generate(
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ async def create_completion(
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a similar fix is required in serving_chat.py?

Also would be good to catch this issue in the test_serving_chat_data_parallel_rank_extraction test

Even better, it would be great to add a similar test for serving_completion!

Copy link
Copy Markdown
Contributor Author

@inkcherry inkcherry Dec 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@markmc Thanks for the comment. I've added it to serving_chat.py.

I noticed that mock objects won't trigger this error. So I added a test with a real engine for coverage, placed after the test_dp_rank_argument test.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be a breaking change for the current API?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually not. When unspecified, it defaults to None and uses the default DP load-balancing algorithm.

)

generator = self.engine_client.generate(
Expand Down
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,7 @@ async def _process_inputs(
lora_request: LoRARequest | None,
trace_headers: Mapping[str, str] | None,
priority: int,
data_parallel_rank: int | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
Expand All @@ -1246,6 +1247,7 @@ async def _process_inputs(
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return engine_request, tokenization_kwargs

Expand Down