[Bugfix] fix DP-aware routing in OpenAI API requests #29002

Merged
njhill merged 10 commits into vllm-project:main from inkcherry:fix_dp_router on Dec 18, 2025

Conversation

inkcherry (Contributor) commented Nov 19, 2025

Purpose

Fix #24945.
In add_request, duplicate initialization is skipped when the prompt is already an EngineCoreRequest, but during the earlier self.processor.process_inputs call, data_parallel_rank is not set. Using -H 'X-data-parallel-rank' to specify the data parallel rank therefore has no effect in this case. cc @njhill

# Convert Input --> Request.
if isinstance(prompt, EngineCoreRequest):
    request = prompt
else:
    assert prompt_text is None
    logger.warning_once(
        "Processor has been moved under OpenAIServing and will "
        "be removed from AsyncLLM in v0.13."
    )
    request = self.processor.process_inputs(
        request_id,
        prompt,
        params,
        arrival_time,
        lora_request,
        tokenization_kwargs,
        trace_headers,
        priority,
        data_parallel_rank,
    )
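
For context, a request pinned to a specific DP rank via this header can be reproduced roughly as follows. This is a minimal sketch: host, port, model name, and prompt are placeholders; only the header name comes from this PR.

# Hypothetical reproduction sketch: pin a completions request to DP rank 1 via
# the X-data-parallel-rank header. URL, model, and prompt are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    headers={"X-data-parallel-rank": "1"},
    json={"model": "my-model", "prompt": "Hello", "max_tokens": 16},
)
print(resp.status_code, resp.json())

Before this fix, such a request was scheduled by the default DP load balancer rather than the requested rank.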

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

gemini-code-assist bot left a comment

Code Review

This pull request aims to fix an issue with Data Parallelism-aware routing by propagating the data_parallel_rank to the engine's processor. The changes correctly pass the rank through serving_completion.py and serving_engine.py.

However, I've identified a critical issue in serving_engine.py where the signature of _process_inputs is changed in a way that breaks other parts of the code and has an incorrect type hint. I've left a comment with a suggested fix to make the new argument optional, which will prevent runtime errors.
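
For illustration only, making the new argument optional along those lines might look like the following sketch; the surrounding parameter list is assumed, not the exact serving_engine.py signature.

# Sketch of an optional data_parallel_rank parameter. Defaulting to None keeps
# existing _process_inputs callers working and preserves the default DP
# load-balancing behaviour when no rank is specified. Other parameters here are
# illustrative placeholders, not the real vLLM signature.
class OpenAIServing:
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt,
        sampling_params,
        *,
        lora_request=None,
        trace_headers=None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ):
        ...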

chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

markmc (Member) previously requested changes Nov 19, 2025

Nice catch!

/cc @Prowindy

    lora_request=lora_request,
    trace_headers=trace_headers,
    priority=request.priority,
    data_parallel_rank=data_parallel_rank,

Looks like a similar fix is required in serving_chat.py?

It would also be good to catch this issue in the test_serving_chat_data_parallel_rank_extraction test.

Even better, it would be great to add a similar test for serving_completion!

inkcherry (Contributor, Author) replied Dec 17, 2025

@markmc Thanks for the comment. I've added it to serving_chat.py.

I noticed that mock objects won't trigger this error. So I added a test with a real engine for coverage, placed after the test_dp_rank_argument test.

    lora_request=lora_request,
    trace_headers=trace_headers,
    priority=request.priority,
    data_parallel_rank=data_parallel_rank,

Could this be a breaking change for the current API?

inkcherry (Contributor, Author) replied:

Actually no. When unspecified, it defaults to None and uses the default DP load-balancing algorithm.

Prowindy (Contributor) commented:

@inkcherry could you clarify what failure you saw, crash or request failure? Any repro steps will be helpful.

I think data_parallel_rank isn't a must-have for all endpoints. Only the endpoints supported by https://github.com/vllm-project/router would have the X-data-parallel-rank HTTP header set and need data_parallel_rank. When it is missing, requests should fall back to the normal routing mode.
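
The fallback behaviour described above amounts to: use the rank from the header when present, otherwise leave it as None so the default load balancing applies. A minimal sketch, assuming a FastAPI request object and a hypothetical helper name:

# Illustrative sketch only: read the DP rank from the incoming HTTP request,
# falling back to None (default DP load balancing) when the header is absent.
# The helper name is hypothetical; vLLM's actual code may differ.
from fastapi import Request


def get_data_parallel_rank(raw_request: Request) -> int | None:
    value = raw_request.headers.get("X-data-parallel-rank")
    return int(value) if value is not None else None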

inkcherry (Contributor, Author) commented Dec 17, 2025

> @inkcherry could you clarify what failure you saw, crash or request failure? Any repro steps will be helpful.
>
> I think data_parallel_rank isn't a must-have for all endpoints. Only the endpoints supported by https://github.com/vllm-project/router would have the X-data-parallel-rank HTTP header set and need data_parallel_rank. When it is missing, requests should fall back to the normal routing mode.

Thank you for your great work and feedback. I noticed that it does not crash; instead, the specified rank is silently ignored (equivalent to not specifying it). I have added tests to cover this.

I agree with your observation; I did not add support for new endpoints.

njhill (Member) left a comment

Thanks @inkcherry!

This doesn't need to be addressed in this PR, but I don't see why we would only want to support this header on the chat and completion endpoints; it could apply similarly to all of the endpoints.

njhill added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Dec 18, 2025

njhill dismissed markmc's stale review on December 18, 2025 17:50:

Review comments addressed.

njhill merged commit 500f26e into vllm-project:main on Dec 18, 2025
48 checks passed

github-project-automation bot moved this to Done in NVIDIA on Dec 18, 2025

Labels

ci/build, frontend, nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

Status: Done

5 participants