-
-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[BugFix] Make PD work with Ray #21072
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
951096e
b629b86
c0f9c92
80d861e
1c63f8e
913cd52
9d4c583
ac43f24
2013ef6
7e4bf72
c6c48c5
5e97ce6
f04be9f
23409ae
c70c5c1
ee04a92
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,28 +1,12 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Adding the tests and bug fix from #21048.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That has now been merged to main, so this can be rebased. |
||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import threading | ||
| from collections import defaultdict | ||
| from concurrent.futures import Future | ||
| from typing import Optional | ||
|
|
||
| from vllm.v1.executor.multiproc_executor import MultiprocExecutor | ||
| from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator | ||
| from vllm.v1.outputs import ModelRunnerOutput | ||
|
|
||
|
|
||
| class DummyMultiprocExecutor(MultiprocExecutor): | ||
|
|
||
| def __init__(self, output_rank, world_size): | ||
| # Manually initialize minimal required fields | ||
| self.output_rank = output_rank | ||
| self.world_size = world_size | ||
| self._send_remaining_count = defaultdict[str, | ||
| int](lambda: self.world_size) | ||
| self._recv_remaining_count = defaultdict[str, | ||
| int](lambda: self.world_size) | ||
| self.io_thread_pool = None | ||
| self.shutdown_event = threading.Event() | ||
|
|
||
|
|
||
| class DummyModelRunnerOutput(ModelRunnerOutput): | ||
|
|
||
| def __init__(self, | ||
|
|
@@ -33,14 +17,14 @@ def __init__(self, | |
|
|
||
|
|
||
| def test_aggregate_workers_output(): | ||
| executor = DummyMultiprocExecutor(output_rank=0, world_size=2) | ||
| aggregator = KVOutputAggregator(world_size=2) | ||
|
|
||
| output1 = DummyModelRunnerOutput(finished_sending={'req1'}, | ||
| finished_recving={'req2'}) | ||
| output2 = DummyModelRunnerOutput(finished_sending=None, | ||
| finished_recving=None) | ||
|
|
||
| aggregated = executor._aggregate_workers_output([output1, output2]) | ||
| aggregated = aggregator.aggregate([output1, output2]) | ||
|
|
||
| assert aggregated is output1 | ||
| assert aggregated.finished_sending is None | ||
|
|
@@ -51,7 +35,7 @@ def test_aggregate_workers_output(): | |
| output2 = DummyModelRunnerOutput(finished_sending={'req1'}, | ||
| finished_recving=None) | ||
|
|
||
| aggregated = executor._aggregate_workers_output([output1, output2]) | ||
| aggregated = aggregator.aggregate([output1, output2]) | ||
|
|
||
| assert aggregated is output1 | ||
| assert aggregated.finished_sending == {'req1'} | ||
|
|
@@ -62,20 +46,19 @@ def test_aggregate_workers_output(): | |
| output2 = DummyModelRunnerOutput(finished_sending={'req1'}, | ||
| finished_recving={'req2'}) | ||
|
|
||
| aggregated = executor._aggregate_workers_output([output1, output2]) | ||
| aggregated = aggregator.aggregate([output1, output2]) | ||
|
|
||
| assert aggregated is output1 | ||
| assert aggregated.finished_sending is None | ||
| assert aggregated.finished_recving == {'req2'} | ||
|
|
||
|
|
||
| def test_async_aggregate_workers_output(): | ||
| executor = DummyMultiprocExecutor(output_rank=0, world_size=2) | ||
| aggregator = KVOutputAggregator(world_size=2) | ||
|
|
||
| future1: Future[DummyModelRunnerOutput] = Future() | ||
| future2: Future[DummyModelRunnerOutput] = Future() | ||
| result_future = executor._async_aggregate_workers_output( | ||
| [future1, future2]) | ||
| result_future = aggregator.async_aggregate([future1, future2]) | ||
|
|
||
| output1 = DummyModelRunnerOutput(finished_sending={'req1'}, | ||
| finished_recving={'req2'}) | ||
|
|
@@ -92,8 +75,7 @@ def test_async_aggregate_workers_output(): | |
|
|
||
| future1 = Future() | ||
| future2 = Future() | ||
| result_future = executor._async_aggregate_workers_output( | ||
| [future1, future2]) | ||
| result_future = aggregator.async_aggregate([future1, future2]) | ||
|
|
||
| output1 = DummyModelRunnerOutput(finished_sending=None, | ||
| finished_recving=None) | ||
|
|
@@ -110,8 +92,7 @@ def test_async_aggregate_workers_output(): | |
|
|
||
| future1 = Future() | ||
| future2 = Future() | ||
| result_future = executor._async_aggregate_workers_output( | ||
| [future1, future2]) | ||
| result_future = aggregator.async_aggregate([future1, future2]) | ||
|
|
||
| output1 = DummyModelRunnerOutput(finished_sending=None, | ||
| finished_recving=None) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,12 +3,18 @@ | |
| """ | ||
| KV cache helper for store. | ||
| """ | ||
| from collections import defaultdict | ||
| from collections.abc import Sequence | ||
| from concurrent.futures import CancelledError, Future | ||
| from typing import Optional, cast | ||
|
|
||
| import torch | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm import _custom_ops as ops | ||
| from vllm.config import VllmConfig, get_current_vllm_config | ||
| from vllm.logger import init_logger | ||
| from vllm.v1.outputs import ModelRunnerOutput | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -107,3 +113,87 @@ def get_kv_connector_cache_layout(): | |
| "layout to HND for better xfer performance.") | ||
| return "HND" | ||
| return "NHD" | ||
|
|
||
|
|
||
| class KVOutputAggregator: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This utility class LGTM |
||
| """Utility class to aggregate the output of all workers into a single | ||
| output corresponding to Rank 0 for scheduler.""" | ||
|
|
||
| def __init__(self, world_size: int): | ||
| # Complete transfer tracker. Used to track finished requests. | ||
| # [req_id -> n_remaining_workers]; a missing key starts at world_size | ||
| # and is decremented once per worker report in aggregate(). | ||
| self._recv_remaining_count = defaultdict[str, int](lambda: world_size) | ||
| self._send_remaining_count = defaultdict[str, int](lambda: world_size) | ||
|
|
||
| def aggregate(self, | ||
| outputs: list[ModelRunnerOutput], | ||
| output_rank: int = 0) -> ModelRunnerOutput: | ||
| # aggregate finished_sending, finished_recving from all workers | ||
| # A request id is only reported as finished once it has been counted | ||
| # world_size times; the counters live on self and persist across calls. | ||
|
|
||
| def update_finished_set(req_ids: Optional[set[str]], | ||
kouroshHakha marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| remaining_count_dict: dict[str, int], | ||
| finished_set: set[str]) -> None: | ||
| for req_id in req_ids or (): | ||
| new_count = remaining_count_dict[req_id] - 1 | ||
| # count reached zero: world_size reports have been seen, so emit the | ||
| # request and drop its counter | ||
| if new_count == 0: | ||
| finished_set.add(req_id) | ||
| del remaining_count_dict[req_id] | ||
| else: | ||
| remaining_count_dict[req_id] = new_count | ||
|
|
||
| finished_sending = set[str]() | ||
| finished_recving = set[str]() | ||
| for output in outputs: | ||
| update_finished_set(output.finished_sending, | ||
| self._send_remaining_count, finished_sending) | ||
| update_finished_set(output.finished_recving, | ||
| self._recv_remaining_count, finished_recving) | ||
|
|
||
| # select output of the worker specified by output_rank | ||
| output = outputs[output_rank] | ||
|
|
||
| # set the aggregated finished_sending / finished_recving | ||
| # if output.finished_sending/recving is not empty, but the other ranks | ||
| # still have unfinished send/recv, we want to set the aggregated | ||
| # finished_sending/recving to None until all ranks have finished | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason why it's set to None instead of empty set ?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's imposed by higher level logic. This part of the PR is inheriting existing logic on master at the time btw. |
||
| # send/recv | ||
| output.finished_sending = finished_sending if finished_sending else None | ||
| output.finished_recving = finished_recving if finished_recving else None | ||
|
|
||
| # NOTE: this is outputs[output_rank] itself, mutated in place, not a copy | ||
| return output | ||
|
|
||
| def async_aggregate(self, | ||
| output_futures: Sequence[Future[ModelRunnerOutput]], | ||
| output_rank: int = 0) -> Future[ModelRunnerOutput]: | ||
| """Takes a list of futures and returns a single future which resolves | ||
| to the aggregated output (the result of aggregate()) once every input | ||
| future has produced a result. Cancellation of, or an exception from, | ||
| an input future is propagated to the returned future.""" | ||
| result_future: Future[ModelRunnerOutput] = Future() | ||
|
|
||
| outputs: list[Optional[ModelRunnerOutput]] = [None | ||
| ] * len(output_futures) | ||
|
|
||
| def make_callback(idx): | ||
|
|
||
| def callback(fut): | ||
| if result_future.done(): | ||
| return | ||
|
|
||
| try: | ||
| outputs[idx] = fut.result() | ||
| except CancelledError: | ||
| result_future.cancel() | ||
| except Exception as e: | ||
| result_future.set_exception(e) | ||
|
|
||
| # this check assumes io_thread_pool uses a single thread | ||
| # NOTE(review): all() tests truthiness, so this relies on every | ||
| # ModelRunnerOutput being truthy - confirm | ||
| if all(outputs): | ||
| result_future.set_result( | ||
| self.aggregate(cast(list[ModelRunnerOutput], outputs), | ||
| output_rank)) | ||
|
|
||
| return callback | ||
|
|
||
| for i, output_future in enumerate(output_futures): | ||
| output_future.add_done_callback(make_callback(i)) | ||
|
|
||
| return result_future | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kouroshHakha could you elaborate on the resource cleanup problem with mp? Are there processes left running?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, the second parametrization, regardless of whether it's "mp" or "ray", hits an OOM issue on GPU. I tried using
`vllm.distributed.cleanup_dist_env_and_memory`, but it didn't quite work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realized this test originally did not run on mp: `distributed_executor_backend` used to be None. So I changed the latest version back to None, which only adds ray on top of what was covered before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, we should look into the mp thing as a follow-on I guess