Conversation

@Prowindy (Contributor) commented Sep 16, 2025

Signed-off-by: Cong Chen [email protected]

Purpose

This PR adds support for the data_parallel_rank parameter in vLLM's OpenAI API, enabling external routers to specify which data parallel rank
should handle a request. This enhancement improves load distribution and cache locality in multi-GPU deployments by allowing DP-aware routing
from router to engine core.

Key functionality:

  • External routers can inject data_parallel_rank into API requests
  • vLLM server routes requests to specific GPU ranks based on the parameter
  • Maintains backward compatibility when parameter is not provided
  • Enables intelligent load balancing and cache-aware routing strategies (see the request sketch after this list)
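For illustration, a hedged sketch of what such a request might look like under the body-field design described here; the URL and model are placeholders, and note that the review discussion below later moves this to an X-data-parallel-rank header:

    # Hypothetical illustration only: an external router pins a completion
    # request to DP rank 2 via the proposed body field. Not the merged design,
    # which uses a header instead (see the review discussion below).
    import requests

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "prompt": "Hello",
        "max_tokens": 16,
        "data_parallel_rank": 2,  # rank chosen by the external router
    }
    resp = requests.post("http://localhost:8000/v1/completions", json=payload)
    print(resp.json())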

Test Plan

  1. End-to-End DP-Aware Routing Test:
    # Start vLLM server with 8 DP ranks
    vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
      --data-parallel-size 8 --tensor-parallel-size 1 --port 8000
    
    # Start vLLM-router with DP-aware routing (selective ranks 0,2,4,6)
    vllm-router --worker-urls http://0.0.0.0:8000 \
      --dp-aware --policy round_robin --port 30000
    
    # Generate test traffic
    vllm bench serve --num-prompts 100 --port 30000 \
      --endpoint /v1/completions --max-concurrency 32
    
  2. Validation Methods:
    - Added comprehensive logging throughout the request pipeline
    - Implemented rank counters to track request distribution
    - Verified GPU device assignment for each request
    - Tested with multiple routing policies (round_robin, cache_aware)

Test Result

✅ Successful DP-Aware Routing Validation:

  1. Router Discovery: Successfully identified 4 DP-aware workers
    workers: ["http://0.0.0.0:8000@0", "http://0.0.0.0:8000@2", "http://0.0.0.0:8000@4", "http://0.0.0.0:8000@6"]
  2. Request Distribution: Balanced traffic across selected ranks
    Current counters: {0: 2, 2: 3, 4: 3, 6: 2}
  3. End-to-End Pipeline Verification:
    API_SERVER_DEBUG: Received completion request with data_parallel_rank=4
    ROUTING_DEBUG: Using data_parallel_rank=4 directly
    ENGINE_DISPATCH DEBUG: Request dispatched to engine identity 0400 (rank 4)
    GPU_DEVICE_DEBUG: Request received on GPU device, parallel_config.rank=4
  4. Multi-Process Validation: Confirmed different engine processes correctly received assigned requests:
    - EngineCore_0 pid=296341 → Rank 0 requests
    - EngineCore_2 pid=296343 → Rank 2 requests
    - EngineCore_4 pid=296345 → Rank 4 requests
    - EngineCore_6 pid=296347 → Rank 6 requests

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively enables DP-aware routing by introducing the data_parallel_rank parameter in OpenAI API requests. The changes are well-implemented across the protocol, serving layer, and client logic, and include a new test for validation. The core functionality appears solid. My review identified one critical issue in a new Kubernetes manifest file which seems to be a copy-paste error, making the file invalid.

Comment on lines 98 to 195
restartPolicy: Never# vLLM Bench deployment on Node 1 targeting the router - Fixed version
apiVersion: v1
kind: Pod
metadata:
  name: vllm-node1-bench
  labels:
    app: vllm-node1-bench
    node-role: load-tester
spec:
  containers:
  - name: vllm-bench
    image: vllm/vllm-openai:latest
    command: ["/bin/bash", "-c"]
    args:
    - |
      set -e
      echo "🏗️ Setting up vLLM Bench on Node 1..."
      echo "🎯 Target: vLLM Router managing 7 servers"

      # Wait for router to be ready
      echo "⏳ Waiting for vLLM router to be ready..."
      router_url="http://vllm-router-all-nodes-service.default.svc.cluster.local:8080"

      while ! curl -s "$router_url/health" > /dev/null 2>&1; do
        echo "Router at $router_url not ready, waiting 10 seconds..."
        sleep 10
      done

      echo "✅ Router is ready! Router health:"
      curl -s "$router_url/health"

      echo "🚀 Starting vLLM bench serve load test (1-minute test)..."
      echo "📊 Configuration:"
      echo "  • Target URL: $router_url"
      echo "  • Model: facebook/opt-350m"
      echo "  • Requests: 200 (1-minute test)"
      echo "  • Concurrency: 16"
      echo "  • Input length: 100 tokens"
      echo "  • Output length: 50 tokens"

      # Run vLLM bench serve command (1-minute test)
      vllm bench serve \
        --dataset-name random \
        --num-prompts 200 \
        --model facebook/opt-350m \
        --random-input-len 100 \
        --random-output-len 50 \
        --endpoint /v1/completions \
        --base-url "$router_url" \
        --max-concurrency 16 \
        --save-result \
        --ignore-eos \
        --served-model-name facebook/opt-350m \
        --result-filename /tmp/bench_results.json

      echo "✅ Load test completed!"
      echo "📊 Results saved to /tmp/bench_results.json"

      # Show summary of results using simple shell commands
      if [ -f /tmp/bench_results.json ]; then
        echo "📈 Load test results available at /tmp/bench_results.json"
        echo "File size: $(wc -c < /tmp/bench_results.json) bytes"
      fi

      # Keep container alive for result analysis
      echo "📈 Keeping container alive for result analysis..."
      echo "Access results with: kubectl exec vllm-node1-bench -- cat /tmp/bench_results.json"

      while true; do
        echo "📊 $(date): Load test completed. Router status:"
        curl -s "$router_url/health" || echo "Router not responding"
        sleep 300  # Check every 5 minutes
      done
    ports:
    - containerPort: 8080
      name: http
    env:
    - name: PYTHONUNBUFFERED
      value: "1"
    - name: CUDA_VISIBLE_DEVICES
      value: ""  # No GPU needed for bench client
    volumeMounts:
    - mountPath: /tmp
      name: results-volume
    resources:
      requests:
        cpu: "4"
        memory: "8Gi"
      limits:
        cpu: "8"
        memory: "16Gi"
  volumes:
  - name: results-volume
    emptyDir: {}
  # Pin to Node 1 specifically
  nodeSelector:
    kubernetes.io/hostname: ip-192-168-45-140.us-west-2.compute.internal
  restartPolicy: Never
@gemini-code-assist bot (Contributor):

critical

The content of this YAML file is duplicated from line 98 onwards, which makes the file invalid and will cause parsing errors. This appears to be a copy-paste error that needs to be fixed by removing the duplicated section.

  restartPolicy: Never

@robertgshaw2-redhat (Collaborator) commented Sep 16, 2025

I think this should be implemented as a header rather than in the body, similar to how external processes can add a request ID:

    return raw_request.headers.get("X-Request-Id", default)

So this could be something like X-data-parallel-rank.

This is more canonical with how extra arguments are added to a standard API.
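For illustration, a minimal sketch of the header-based extraction being suggested, mirroring the X-Request-Id pattern quoted above; the helper name and fallback behavior here are assumptions, not the merged code:

    # Sketch of header-based DP-rank extraction, mirroring the X-Request-Id
    # pattern. Helper name and fallback semantics are assumptions.
    from fastapi import Request

    def get_data_parallel_rank(raw_request: Request | None) -> int | None:
        if raw_request is None:
            return None
        value = raw_request.headers.get("X-data-parallel-rank")
        if value is None:
            return None
        try:
            return int(value)
        except ValueError:
            return None  # malformed header: treat as a regular request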

@robertgshaw2-redhat robertgshaw2-redhat changed the title Enable DP-aware routing via data_parallel_rank in OpenAI API requests [EP/DP][API Server] Enable DP-aware routing in OpenAI API requests Sep 16, 2025
@Prowindy Prowindy force-pushed the data-parallel-rank-support branch 2 times, most recently from 4b4fb6a to a56748d Compare September 16, 2025 22:00
@Prowindy (Contributor, Author) replied:

@robertgshaw2-redhat The current approach is consistent with how SGLang handles DP-aware requests. With this change to the vLLM server, the sgl-router fork can send requests directly to the vLLM server without requiring any API modifications.

That said, your suggestion is valid: P/D disaggregation also embeds the prefill/decode address in the X-Request-Id. I'll make a few adjustments to incorporate your recommendation.

@Prowindy Prowindy force-pushed the data-parallel-rank-support branch 2 times, most recently from 0d5207c to a5ef013 Compare October 15, 2025 00:55



@router.get("/get_server_info")
A collaborator commented:
Is this needed? I think it's usually not great to allow external users to see this type of server info.

I think that instead we should require the server to be in DEV mode to expose this.

@Prowindy (Contributor, Author) commented Oct 24, 2025

This design is inherited from the SGLang API, which allows its router to fetch the entire server config.

We have a couple of alternative implementation options:

  1. Pass data_parallel_size as an input flag to the router. This would require all vLLM servers to use the same DP size (or we could further separate it into prefill and decode configs).
  2. Create a /get_data_parallel_size endpoint, exposing only this specific configuration to callers (a sketch of this option follows below).
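A minimal sketch of option 2, for illustration only; the endpoint path matches the option above, while the app wiring and constant are hypothetical stand-ins for the server's actual config plumbing:

    # Sketch of option 2: a dedicated endpoint exposing only the DP size.
    from fastapi import FastAPI

    app = FastAPI()
    DATA_PARALLEL_SIZE = 8  # would come from vllm_config.parallel_config

    @app.get("/get_data_parallel_size")
    async def get_data_parallel_size() -> dict[str, int]:
        return {"data_parallel_size": DATA_PARALLEL_SIZE}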

@robertgshaw2-redhat Which of these approaches would you recommend?

A member replied:
I agree it's good to keep the discovery part separate (can figure out in future PR). The header changes in this one look good to me now.

@mergify bot commented Oct 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Prowindy.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 28, 2025
@Prowindy Prowindy force-pushed the data-parallel-rank-support branch from 66fa31c to 0fe49c6 Compare October 28, 2025 17:09
@mergify bot commented Oct 28, 2025

Documentation preview: https://vllm--24945.org.readthedocs.build/en/24945/

@Prowindy Prowindy force-pushed the data-parallel-rank-support branch from 0fe49c6 to d9cb45e Compare October 28, 2025 17:18
@mergify mergify bot removed the needs-rebase label Oct 28, 2025
@Prowindy Prowindy force-pushed the data-parallel-rank-support branch from 7e9dcb5 to 8a5a86f Compare October 28, 2025 18:06
@njhill (Member) left a comment

Thanks @Prowindy! LGTM

Looks like there is a commit which needs DCO sign-off (see https://github.com/vllm-project/vllm/pull/24945/checks?check_run_id=53896850372), and some pre-commit errors.

Also I wonder if we could add something to the docs to cover this. Perhaps add a note to the external router section of the data parallel page: https://docs.vllm.ai/en/latest/serving/data_parallel_deployment.html#external-load-balancing

@Prowindy Prowindy force-pushed the data-parallel-rank-support branch from 8a5a86f to 5d00c87 Compare October 28, 2025 20:30
        return raw_request.headers.get("X-Request-Id", default)

    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
A collaborator commented:
An attack vector:

  • the user specifies an X-data-parallel-rank that is > DP size or is negative
  • the user specifies X-data-parallel-rank when the server is not using DP

What happens? We need to be careful, as this is untrusted user input that could cause DoS issues.

Perhaps we should only parse this field if the server operator opts in, and log a warning about it. We also need to sanitize the inputs (a sketch of this idea follows below).
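A minimal sketch of the opt-in idea suggested here, for illustration; the `enabled` flag and function name are hypothetical and do not reflect vLLM's actual option names:

    # Sketch only: opt-in parsing of the untrusted X-data-parallel-rank header.
    # The `enabled` flag and function name are hypothetical, not vLLM's API.
    import logging

    logger = logging.getLogger(__name__)

    def parse_dp_rank(raw_value: str | None, dp_size: int, enabled: bool) -> int | None:
        """Return a sanitized DP rank, or None to treat it as a regular request."""
        if raw_value is None or not enabled:
            return None
        try:
            rank = int(raw_value)
        except ValueError:
            logger.warning("Ignoring malformed X-data-parallel-rank: %r", raw_value)
            return None
        if not 0 <= rank < dp_size:
            logger.warning("Ignoring out-of-range X-data-parallel-rank: %d", rank)
            return None
        return rank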

@Prowindy (Contributor, Author) replied:

The engine performs a sanity check on the input DP rank range before utilizing it:

    data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
    if data_parallel_rank is not None and not (
        0 <= data_parallel_rank < data_parallel_size
    ):
        raise ValueError(
            f"data_parallel_rank {data_parallel_rank} "
            f"is out of range [0, {data_parallel_size})."
        )

Do we want to enhance safety by clamping, or by raising an exception during parsing?

@Prowindy (Contributor, Author) replied:

Invalid-dp-tests.txt

@robertgshaw2-redhat I tested the vLLM patch by sending requests with invalid DP ranks, including negative numbers, out-of-range values, and non-integer characters. In these cases, the system falls back to a regular request without DP rank info. Is this our expected behavior?

@Prowindy Prowindy force-pushed the data-parallel-rank-support branch from 5d00c87 to c5b7d26 Compare October 28, 2025 20:38
@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) October 29, 2025 16:30
@github-actions github-actions bot added the ready label Oct 29, 2025
auto-merge was automatically disabled October 29, 2025 22:55

Head branch was pushed to by a user without write access

@Prowindy Prowindy force-pushed the data-parallel-rank-support branch 4 times, most recently from 92eac51 to 074249a Compare October 30, 2025 02:32
Signed-off-by: Cong Chen <[email protected]>

This commit adds support for data parallel rank routing in the OpenAI API
to enable intelligent load balancing across data parallel instances.

Key changes:
- Add _get_data_parallel_rank() method to extract X-data-parallel-rank header
- Pass data_parallel_rank to engine.generate() in chat and completion APIs
- Add /get_server_info endpoint to return DP size for router discovery
- Add test coverage for data_parallel_rank extraction

The X-data-parallel-rank header allows external routers to specify which
data parallel rank should handle a request, enabling DP-aware request
routing and load balancing.
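A hypothetical client-side illustration of the merged header behavior described in this commit message; the URL and model are placeholders:

    # Hypothetical usage: an external router pins a request to DP rank 4
    # via the X-data-parallel-rank header. URL and model are placeholders.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        headers={"X-data-parallel-rank": "4"},
        json={"model": "facebook/opt-350m", "prompt": "Hello", "max_tokens": 16},
    )
    print(resp.json())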
@Prowindy Prowindy force-pushed the data-parallel-rank-support branch from 074249a to 2fc7bbc Compare October 30, 2025 16:42
@WoosukKwon WoosukKwon merged commit a2981c4 into vllm-project:main Oct 30, 2025
44 of 48 checks passed
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
eldarkurtic pushed a commit to eldarkurtic/vllm that referenced this pull request Nov 12, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels

documentation, frontend, ready, v1