Merged

Commits (65)

9ca44ce
[V1] AsyncLLM data parallel WIP
njhill Feb 26, 2025
3f51611
Handle pausing loop
njhill Feb 27, 2025
d8c591e
More single-node updates
njhill Feb 27, 2025
65e225d
some cleanup
njhill Feb 27, 2025
5ce57b6
fix up utility methods
njhill Feb 27, 2025
a3f1102
revert config check
njhill Feb 27, 2025
a66fb01
fixes
njhill Feb 27, 2025
67672c2
cleanup
njhill Feb 27, 2025
cf52fbf
fixes
njhill Feb 27, 2025
a4ec81b
reconcile with LLMEngine DP in decoupled engine case
njhill Feb 27, 2025
292aa00
minor simplification
njhill Feb 27, 2025
4b62ffd
rework
njhill Feb 28, 2025
407c72e
class refactor
njhill Mar 1, 2025
31bf7ea
fix
njhill Mar 1, 2025
fde51ce
adjust core engine init
njhill Mar 1, 2025
d5a3e68
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill Mar 3, 2025
6d89a1b
fix new typing
njhill Mar 3, 2025
448abd9
fix :facepalm:
njhill Mar 3, 2025
a1e513e
bind socket first
njhill Mar 3, 2025
50cf64c
do you have to let it linger
njhill Mar 3, 2025
f365998
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 3, 2025
b2571f0
add comments
njhill Mar 4, 2025
32c6f24
aggregate stats
njhill Mar 4, 2025
9c30cd7
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 4, 2025
672d07e
Fix test
njhill Mar 4, 2025
dea382b
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 5, 2025
d24a626
fix and minor cleanup
njhill Mar 5, 2025
cd03c80
Add CI test
njhill Mar 6, 2025
f1004b7
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 6, 2025
d3298fa
Some simplification and fixes
njhill Mar 6, 2025
74dde48
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 6, 2025
5fe1b75
address @markmc's stats suggestion
njhill Mar 6, 2025
648659f
address @tms's arg comment
njhill Mar 6, 2025
119d1ec
fix utility method breakage
njhill Mar 6, 2025
55328ee
rename AsyncMPClient output_processor to output_handler
njhill Mar 6, 2025
4f5330e
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 6, 2025
48770ec
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 7, 2025
d229f4d
Fix
njhill Mar 7, 2025
2f91cc4
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 15, 2025
518047a
Remove redundant logic related to removed stats aggregation
njhill Mar 13, 2025
cb2b099
Fixes
njhill Mar 15, 2025
ff1137a
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill Mar 16, 2025
61f4fcb
fix issue from main merge
njhill Mar 16, 2025
44874c2
remove leftover unused field
njhill Mar 17, 2025
66fc582
Fix offline DP compatibility
njhill Mar 17, 2025
7764466
Add timeout to data_parallel.py
njhill Mar 17, 2025
51e8bf0
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill Mar 17, 2025
f692c12
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 19, 2025
47b5e1c
Enable less-frequent all-reduce optimization
njhill Mar 20, 2025
f226139
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 20, 2025
af47920
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 20, 2025
693c521
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 20, 2025
6e131e3
clean distributed shutdown
njhill Mar 20, 2025
d9ac856
address misc loose-ends
njhill Mar 20, 2025
3abbdef
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 21, 2025
b18417e
further tweaks
njhill Mar 21, 2025
56b2b78
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill Mar 25, 2025
05ab310
Additional debug
njhill Mar 25, 2025
5295c34
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 27, 2025
4f897b8
Address review comments on tests
njhill Mar 27, 2025
62f32ed
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 27, 2025
771ccf1
Fix env var fallback
njhill Mar 27, 2025
05a0e83
Fix test supports_v1 check
njhill Mar 27, 2025
bc41b13
Fix yapf :facepalm:
njhill Mar 27, 2025
ccecb42
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 27, 2025
5 changes: 5 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -137,8 +137,10 @@ steps:
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
- tests/v1/test_async_llm_dp.py
commands:
- VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
@@ -505,7 +507,10 @@ steps:
- vllm/worker/worker.py
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py
- vllm/v1/engine/
commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
8 changes: 4 additions & 4 deletions tests/v1/engine/test_engine_core_client.py
@@ -165,11 +165,11 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

core_client: SyncMPClient = client

result = core_client._call_utility("echo", "testarg")
result = core_client.call_utility("echo", "testarg")
assert result == "testarg"

with pytest.raises(Exception) as e_info:
core_client._call_utility("echo", None, "help!")
core_client.call_utility("echo", None, "help!")

assert str(e_info.value) == "Call to echo method failed: help!"

@@ -236,10 +236,10 @@ async def test_engine_core_client_asyncio(monkeypatch):

core_client: AsyncMPClient = client

result = await core_client._call_utility_async("echo", "testarg")
result = await core_client.call_utility_async("echo", "testarg")
assert result == "testarg"

with pytest.raises(Exception) as e_info:
await core_client._call_utility_async("echo", None, "help!")
await core_client.call_utility_async("echo", None, "help!")

assert str(e_info.value) == "Call to echo method failed: help!"
102 changes: 102 additions & 0 deletions tests/v1/test_async_llm_dp.py
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
from contextlib import ExitStack
from typing import Optional

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient

if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
allow_module_level=True)

Review comment (Member):
Not sure if DP works on TPU or AMD GPUs, but modify this reason string since V1 works there at least experimentally?

vllm/vllm/engine/arg_utils.py, lines 1669 to 1675 in d0cfec7:
# No support for device type other than CUDA, AMD (experiemntal) or
# TPU (experimental) so far.
if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
_raise_or_fallback(
feature_name=f"device type={current_platform.device_type}",
recommend_to_remove=False)
return False

Review comment (Member):
We could actually use supports_v1 now that this PR has landed (probably only want to turn tests on for CUDA and ROCm though)

#15417



async def generate(engine: AsyncLLM,
request_id: str,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)

count = 0
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True,
output_kind=output_kind,
temperature=0,
prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):

num_tokens = len(out.outputs[0].token_ids)
if output_kind == RequestOutputKind.DELTA:
count += num_tokens
else:
count = num_tokens

await asyncio.sleep(0.)

return count, request_id


@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_load(monkeypatch, output_kind: RequestOutputKind):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

Review comment (Member):
Remove this now that V1 is on by default?


engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
enforce_eager=True,
disable_log_requests=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=int(os.getenv("DP_SIZE", 2)),
)

prompt = "This is a test of data parallel"

engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

NUM_REQUESTS = 100
NUM_EXPECTED_TOKENS = 10

request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

# Create concurrent requests.
tasks = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))

# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
for task in pending:
task.cancel()
for task in done:
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
f"{request_id} generated {num_generated_tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")

assert not engine.output_processor.has_unfinished_requests()

# testing internals here which may break
core_client: DPAsyncMPClient = engine.engine_core
assert core_client.num_engines_running == 0
assert not core_client.reqs_in_flight
16 changes: 11 additions & 5 deletions vllm/config.py
@@ -38,7 +38,8 @@
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, random_uuid, resolve_obj_by_qualname)
get_cpu_memory, get_open_port, random_uuid,
resolve_obj_by_qualname)

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -1435,10 +1436,15 @@ def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size

self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
if self.data_parallel_size > 1:
self.data_parallel_master_port = get_open_port()
# TODO multi-node
else:
self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

self.world_size_across_dp = self.world_size * self.data_parallel_size

if self.distributed_executor_backend == "external_launcher":
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -114,6 +114,7 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
data_parallel_size: int = 1
enable_expert_parallel: bool = False
max_parallel_loading_workers: Optional[int] = None
block_size: Optional[int] = None
@@ -441,6 +442,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=int,
default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.')
parser.add_argument('--data-parallel-size',
'-dp',
type=int,
default=EngineArgs.data_parallel_size,
help='Number of data parallel replicas. '
'MoE layers will be sharded according to the '
'product of the tensor-parallel-size and '
'data-parallel-size.')
parser.add_argument(
'--enable-expert-parallel',
action='store_true',
@@ -1213,6 +1222,7 @@ def create_engine_config(self,
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
data_parallel_size=self.data_parallel_size,
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
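
The new --data-parallel-size / -dp flag flows through EngineArgs into ParallelConfig above. A minimal programmatic sketch of the same knob, mirroring tests/v1/test_async_llm_dp.py (the model name and env-var defaults are illustrative, copied from that test):

    import os

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.v1.engine.async_llm import AsyncLLM

    # TP/DP sizes come from the environment so the same script works under the
    # CI commands above (e.g. TP_SIZE=1 DP_SIZE=2).
    engine_args = AsyncEngineArgs(
        model="ibm-research/PowerMoE-3b",  # illustrative MoE model
        enforce_eager=True,
        tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
        data_parallel_size=int(os.getenv("DP_SIZE", "2")),
    )
    engine = AsyncLLM.from_engine_args(engine_args)
    # ... submit requests, then call engine.shutdown() when finished.
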
14 changes: 9 additions & 5 deletions vllm/utils.py
@@ -517,7 +517,7 @@ def get_open_port() -> int:
dp_port = envs.VLLM_DP_MASTER_PORT
while True:
port = _get_open_port()
if port >= dp_port and port < dp_port + 10:
if dp_port <= port < dp_port + 10:
continue
return port
return _get_open_port()
@@ -2134,19 +2134,23 @@ def make_zmq_socket(
if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path)
socket.bind(path)
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path)
socket.connect(path)
else:
raise ValueError(f"Unknown Socket Type: {socket_type}")

return socket


@contextlib.contextmanager
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
def zmq_socket_ctx(
path: str,
socket_type: Any,
linger: int = 0,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""

ctx = zmq.Context() # type: ignore[attr-defined]
@@ -2157,7 +2161,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
logger.debug("Got Keyboard Interrupt.")

finally:
ctx.destroy(linger=0)
ctx.destroy(linger=linger)


def _check_multiproc_method():
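
Two behavioural changes sit in the vllm/utils.py hunks above: make_zmq_socket now binds on the PULL side and connects on the PUSH side (the receiver owns the address and can come up first), and zmq_socket_ctx exposes a linger argument instead of hard-coding linger=0. A standalone pyzmq sketch of that bind/connect arrangement, not vLLM's helper itself (the ipc path is illustrative):

    import zmq

    ctx = zmq.Context()

    # Receiver owns the endpoint: PULL binds first, with no receive HWM.
    pull = ctx.socket(zmq.PULL)
    pull.setsockopt(zmq.RCVHWM, 0)
    pull.bind("ipc:///tmp/dp_example.sock")

    # Senders come and go: PUSH connects to the already-bound endpoint.
    push = ctx.socket(zmq.PUSH)
    push.setsockopt(zmq.SNDHWM, 0)
    push.connect("ipc:///tmp/dp_example.sock")

    push.send(b"hello")
    assert pull.recv() == b"hello"
    ctx.destroy(linger=0)  # the linger value is what zmq_socket_ctx now takes
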
12 changes: 9 additions & 3 deletions vllm/v1/core/scheduler.py
@@ -31,12 +31,14 @@ def __init__(
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
log_stats: bool,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
self.speculative_config = speculative_config
self.include_finished_set = include_finished_set
self.log_stats = log_stats

# Scheduling constraints.
@@ -583,10 +585,14 @@ def update_from_output(
new_running.append(request)

self.running = new_running
return EngineCoreOutputs(
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(),
)
if self.include_finished_set:
engine_core_outputs.finished_requests = (
scheduler_output.finished_req_ids)

Review comment (Member):
Are these different from the outputs with a non-None finish_reason ? If so, we're duplicating this information because we don't want the core client to loop over all outputs (we do that at the AsyncLLM level)?

Looks like finished_req_ids is intended for the model runner? If we extend their use, we should update the comment describing their purpose

        # The request IDs that are finished in between the previous and the
        # current steps. This is used to notify the workers about the finished
        # requests so that they can free the cached states for those requests.
        # This is flushed at the end of each scheduling step.
        self.finished_req_ids: set[str] = set()

Reply (Member, author):
Are these different from the outputs with a non-None finish_reason ? If so, we're duplicating this information because we don't want the core client to loop over all outputs (we do that at the AsyncLLM level)?

That's correct, and these finished req ids will be "sparse" i.e. only one per request, so the overhead should be minimal. But the alternative would mean looping through all requests in the batch on every step (in the client).

return engine_core_outputs

def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
@@ -655,7 +661,7 @@ def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)

def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0
return len(self.running) > 0 or len(self.waiting) > 0

def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
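
Per the review thread above, finished_requests carries at most one id per finished request, so the client can drop per-request state without scanning every output for a non-None finish_reason (only AsyncLLM walks the full output list). A hypothetical consumer-side sketch; handle_outputs and reqs_in_flight are invented names for illustration, not the actual DPAsyncMPClient code:

    from typing import Any

    def handle_outputs(outputs, reqs_in_flight: dict[str, Any]) -> None:
        # Free bookkeeping for finished requests by id rather than inspecting
        # each RequestOutput for a finish_reason.
        for req_id in (outputs.finished_requests or ()):
            reqs_in_flight.pop(req_id, None)
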
12 changes: 11 additions & 1 deletion vllm/v1/engine/__init__.py
@@ -134,6 +134,15 @@ class EngineCoreOutputs(
timestamp: float = 0.0

utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None

# In DP case, used to signal that the engine is paused.
engine_paused: bool = False

# Set to False to indicate stats should be accumulated rather than
# recorded, when there are remaining outputs from other engines
# still to come for this iteration.
final_outputs_for_step: bool = True

Review comment (Member):
This is only ever set to false on the client side right? Even though this is a struct for core->client comms?

It's basically a signal from DPAsyncMPClient to AsyncLLM output handler?

(Seems nasty, but not 100% sure I've got that right)

Reply (Member, author):
Yes you're right and I also don't like it and had been thinking about alternatives. For this first PR I had been trying to avoid changing too much.

It's a struct from core->client comm but also what's returned from the client. Agree that we could / probably should separate these things. But TBH I'm not set on this mechanism for aggregating the metrics so was expecting it may be reworked anyhow.


def __post_init__(self):
if self.timestamp == 0.0:
@@ -147,4 +156,5 @@ class EngineCoreRequestType(enum.Enum):
"""
ADD = b'\x00'
ABORT = b'\x01'
UTILITY = b'\x02'
START_DP = b'\x02'
UTILITY = b'\x03'
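
A hypothetical sketch of what the comment on final_outputs_for_step describes — accumulate per-engine stats until the final batch of outputs for a step arrives — with record() and on_outputs() invented for illustration, not the actual AsyncLLM output handler:

    pending_stats: list = []

    def record(stats_batch: list) -> None:
        # Stand-in stats sink; a real handler would log or export these.
        print(f"recording {len(stats_batch)} per-engine stats for this step")

    def on_outputs(outputs) -> None:
        pending_stats.append(outputs.scheduler_stats)
        if outputs.final_outputs_for_step:
            # Only record once outputs from all DP engines for this step have
            # arrived; otherwise keep accumulating.
            record(pending_stats.copy())
            pending_stats.clear()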