Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 89 additions & 94 deletions tests/v1/distributed/test_internal_lb_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest_asyncio
import requests

from tests.utils import RemoteOpenAIServer
from tests.utils import ROCM_ENV_OVERRIDES, RemoteOpenAIServer
from tests.v1.utils import check_request_balancing
from vllm.platforms import current_platform

Expand All @@ -27,6 +27,84 @@
NUM_NODES = 2


async def _make_completion_request(
    client: openai.AsyncOpenAI,
    model_name: str,
) -> openai.types.Completion:
    """Issue one completion request and sanity-check the server's reply.

    temperature=1.0 keeps outputs varied between concurrent requests so
    the load balancer sees realistic, non-identical traffic.
    """
    resp = await client.completions.create(
        model=model_name,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=1.0,
    )

    assert resp.id is not None, (
        f"Expected non-None completion id. usage={resp.usage!r}"
    )
    assert resp.choices is not None and len(resp.choices) == 1, (
        f"Expected 1 choice, got "
        f"{len(resp.choices) if resp.choices else 'None'}"
    )

    first = resp.choices[0]
    # At temperature=1.0 the model can sample a stop token immediately,
    # which yields empty text with finish_reason='stop'. That is
    # legitimate output; this helper only cares about load balancing,
    # not generation quality.
    assert first.finish_reason in ("length", "stop"), (
        f"Expected finish_reason 'length' or 'stop', "
        f"got {first.finish_reason!r}. text={first.text!r}"
    )
    if first.finish_reason == "length":
        assert len(first.text) >= 1, (
            f"Expected non-empty text with finish_reason='length', got {first.text!r}"
        )

    usage = resp.usage
    assert usage.prompt_tokens > 0, (
        f"Expected positive prompt_tokens, got {usage.prompt_tokens}"
    )
    assert usage.total_tokens > 0, (
        f"Expected positive total_tokens, got {usage.total_tokens}"
    )
    return resp


async def _run_request_bursts(
    client: openai.AsyncOpenAI,
    model_name: str,
    num_requests: int = 200,
    num_bursts: int = 2,
):
    """Fire `num_bursts` waves of `num_requests` completions and verify
    that every request in every wave succeeds."""
    for burst in range(num_bursts):
        pending = []
        for _ in range(num_requests):
            pending.append(
                asyncio.create_task(_make_completion_request(client, model_name))
            )
            # Small stagger so requests arrive as a stream, not one spike.
            await asyncio.sleep(0.01)

        outcomes = await asyncio.gather(*pending, return_exceptions=True)
        assert len(outcomes) == num_requests, (
            f"Burst {burst}: expected {num_requests} results, got {len(outcomes)}"
        )

        # gather() collected failures as values; surface the first one.
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                raise outcome

        assert all(outcome is not None for outcome in outcomes), (
            f"Burst {burst}: some completions were None"
        )

        # Let the servers settle before the next burst.
        await asyncio.sleep(0.5)


class MultinodeInternalLBServerManager:
"""Manages multi-node data parallel vLLM server instances for internal
load balancer testing using --headless mode."""
Expand Down Expand Up @@ -108,6 +186,7 @@ def start_server(sidx: int, r: int, sargs: list[str]):
auto_port=False,
env_dict={
"VLLM_SERVER_DEV_MODE": "1",
**ROCM_ENV_OVERRIDES,
current_platform.device_control_env_var: ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(r, r + gpus_per_node)
Expand Down Expand Up @@ -229,6 +308,7 @@ def start_api_server():
auto_port=False,
env_dict={
"VLLM_SERVER_DEV_MODE": "1",
**ROCM_ENV_OVERRIDES,
# No GPUs needed for API-only server
},
)
Expand All @@ -249,10 +329,11 @@ def start_engines_server():
engines_server_args,
auto_port=False,
env_dict={
**ROCM_ENV_OVERRIDES,
current_platform.device_control_env_var: ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(self.dp_size * self.tp_size)
)
),
},
)
server.__enter__()
Expand Down Expand Up @@ -395,58 +476,15 @@ async def test_multinode_dp_completion(
servers: list[tuple[RemoteOpenAIServer, list[str]]],
model_name: str,
) -> None:
async def make_request():
completion = await client.completions.create(
model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0
)

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1

choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")

# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion

# Test single request
result = await make_request()
result = await _make_completion_request(client, model_name)
assert result is not None
print("Multi-node internal LB handled single completion request successfully")

await asyncio.sleep(0.5)

# Send multiple requests - internal LB should distribute across DP ranks
num_requests = 200
all_tasks = []
for _ in range(num_requests):
all_tasks.append(asyncio.create_task(make_request()))
await asyncio.sleep(0.01)

results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)

await asyncio.sleep(0.5)

# Second burst of requests
all_tasks = []
for _ in range(num_requests):
all_tasks.append(asyncio.create_task(make_request()))
await asyncio.sleep(0.01)

results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
# Send multiple bursts - internal LB should distribute across DP ranks
await _run_request_bursts(client, model_name)

_, server_args = servers[0]
api_server_count = (
Expand Down Expand Up @@ -570,59 +608,16 @@ async def test_api_only_multinode_dp_completion(
) -> None:
"""Test API-only server with all engines on separate headless server."""

async def make_request():
completion = await api_only_client.completions.create(
model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0
)

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1

choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes
# early or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")

# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion

# Test single request
result = await make_request()
result = await _make_completion_request(api_only_client, model_name)
assert result is not None
print("API-only server handled single completion request successfully")

await asyncio.sleep(0.5)

# Send multiple requests - should be distributed across engines on
# Send multiple bursts - should be distributed across engines on
# headless server
num_requests = 200
all_tasks = []
for _ in range(num_requests):
all_tasks.append(asyncio.create_task(make_request()))
await asyncio.sleep(0.01)

results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)

await asyncio.sleep(0.5)

# Second burst of requests
all_tasks = []
for _ in range(num_requests):
all_tasks.append(asyncio.create_task(make_request()))
await asyncio.sleep(0.01)

results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
await _run_request_bursts(api_only_client, model_name)

api_server, api_server_args = api_only_servers[0]
api_server_count = (
Expand Down
Loading