
Commit 94744ba

[V1] [Feature] Collective RPC (#15444)

Signed-off-by: wwl2755 <[email protected]>

1 parent 4965ec4

File tree: 7 files changed, +86 −10 lines

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 3 deletions

@@ -150,8 +150,8 @@ steps:
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
   - pushd ../examples/offline_inference
-  - VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py
-  - VLLM_ENABLE_V1_MULTIPROCESSING=0 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
+  - python3 rlhf.py
+  - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
   - popd

 - label: Metrics, Tracing Test # 10min
@@ -520,7 +520,7 @@ steps:
   - vllm/v1/engine/
   commands:
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-  - VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
+  - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
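Both removed occurrences of VLLM_ENABLE_V1_MULTIPROCESSING=0 were workarounds for collective RPC not being supported on the V1 multiprocessing path; with the plumbing added below, the RLHF examples and test_collective_rpc.py run on the default V1 engine.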

vllm/engine/llm_engine.py

Lines changed: 11 additions & 2 deletions

@@ -7,8 +7,8 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
-                    List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload

@@ -67,6 +67,7 @@

 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
+_R = TypeVar("_R", default=Any)


 @dataclass
@@ -2123,6 +2124,14 @@ def _build_logits_processors(

         return sampling_params

+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+

 if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
     from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
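The V0 engine simply forwards to its executor; the _R TypeVar (with default=Any) ties the worker method's return type to the element type of the returned list. In string form the method is looked up by name on each worker. A minimal sketch ("report_gpu_blocks" is a made-up worker method name for illustration):

    # One return value per worker, gathered into a list.
    block_counts: list[int] = engine.collective_rpc("report_gpu_blocks",
                                                    timeout=5.0)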

vllm/entrypoints/llm.py

Lines changed: 2 additions & 2 deletions

@@ -492,8 +492,8 @@ def collective_rpc(self,
         It is recommended to use this API to only pass control messages,
         and set up data-plane communication to pass data.
         """
-        executor = self.llm_engine.model_executor
-        return executor.collective_rpc(method, timeout, args, kwargs)
+
+        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

     def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
         """

vllm/v1/engine/core.py

Lines changed: 11 additions & 1 deletion

@@ -8,7 +8,7 @@
 from concurrent.futures import Future
 from inspect import isclass, signature
 from logging import DEBUG
-from typing import Any, Optional
+from typing import Any, Callable, Optional, TypeVar, Union

 import msgspec
 import psutil
@@ -43,6 +43,8 @@

 POLLING_TIMEOUT_S = 2.5

+_R = TypeVar('_R')  # Return type for collective_rpc
+

 class EngineCore:
     """Inner loop of vLLM's Engine."""
@@ -280,6 +282,14 @@ def list_loras(self) -> set[int]:
     def pin_lora(self, lora_id: int) -> bool:
         return self.model_executor.pin_lora(lora_id)

+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+

 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""

vllm/v1/engine/core_client.py

Lines changed: 42 additions & 1 deletion

@@ -12,7 +12,7 @@
 from concurrent.futures import Future
 from dataclasses import dataclass, field
 from threading import Thread
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TypeVar, Union

 import zmq
 import zmq.asyncio
@@ -33,6 +33,8 @@

 AnyFuture = Union[asyncio.Future[Any], Future[Any]]

+_R = TypeVar('_R')  # Return type for collective_rpc
+

 class EngineCoreClient(ABC):
     """
@@ -117,6 +119,13 @@ def list_loras(self) -> set[int]:
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError

@@ -153,6 +162,14 @@ async def list_loras_async(self) -> set[int]:
     async def pin_lora_async(self, lora_id: int) -> bool:
         raise NotImplementedError

+    async def collective_rpc_async(
+            self,
+            method: Union[str, Callable[..., _R]],
+            timeout: Optional[float] = None,
+            args: tuple = (),
+            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        raise NotImplementedError
+

 class InprocClient(EngineCoreClient):
     """
@@ -210,6 +227,13 @@ def list_loras(self) -> set[int]:
     def pin_lora(self, lora_id: int) -> bool:
         return self.engine_core.pin_lora(lora_id)

+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.engine_core.collective_rpc(method, timeout, args, kwargs)
+

 class CoreEngine:
     """One per data parallel rank."""
@@ -505,6 +529,14 @@ def is_sleeping(self) -> bool:
     def execute_dummy_batch(self) -> None:
         self.call_utility("execute_dummy_batch")

+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.call_utility("collective_rpc", method, timeout, args,
+                                 kwargs)
+

 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -636,6 +668,15 @@ async def list_loras_async(self) -> set[int]:
     async def pin_lora_async(self, lora_id: int) -> bool:
         return await self.call_utility_async("pin_lora", lora_id)

+    async def collective_rpc_async(
+            self,
+            method: Union[str, Callable[..., _R]],
+            timeout: Optional[float] = None,
+            args: tuple = (),
+            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return await self.call_utility_async("collective_rpc", method, timeout,
+                                             args, kwargs)
+

 class DPAsyncMPClient(AsyncMPClient):
     """Asyncio-compatible client for multi-proc, multi-engine (data parallel)

vllm/v1/engine/llm_engine.py

Lines changed: 9 additions & 1 deletion

@@ -2,7 +2,7 @@

 from collections.abc import Mapping
 from copy import copy
-from typing import Optional, Union
+from typing import Any, Callable, Optional, Union

 from typing_extensions import TypeVar

@@ -32,6 +32,7 @@
 logger = init_logger(__name__)

 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
+_R = TypeVar("_R", default=Any)


 class LLMEngine:
@@ -282,6 +283,13 @@ def pin_lora(self, lora_id: int) -> bool:
         """Prevent an adapter from being evicted."""
         return self.engine_core.pin_lora(lora_id)

+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.engine_core.collective_rpc(method, timeout, args, kwargs)
+
     def __del__(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
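With this, the V0 and V1 LLMEngine classes expose the same collective_rpc signature, so LLM.collective_rpc behaves identically whichever engine the VLLM_USE_V1 switch selects.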

vllm/v1/serial_utils.py

Lines changed: 8 additions & 0 deletions

@@ -1,13 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0

 import pickle
+from types import FunctionType
 from typing import Any, Optional

+import cloudpickle
 import torch
 from msgspec import msgpack

 CUSTOM_TYPE_TENSOR = 1
 CUSTOM_TYPE_PICKLE = 2
+CUSTOM_TYPE_CLOUDPICKLE = 3


 class MsgpackEncoder:
@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
         # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
         return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))

+    if isinstance(obj, FunctionType):
+        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
+
     return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))


@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
         return torch.from_numpy(pickle.loads(data))
     if code == CUSTOM_TYPE_PICKLE:
         return pickle.loads(data)
+    if code == CUSTOM_TYPE_CLOUDPICKLE:
+        return cloudpickle.loads(data)

     raise NotImplementedError(f"Extension type code {code} is not supported")
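A standalone round-trip sketch of the new extension path (it mirrors the hooks above rather than importing them from vLLM), showing what the change buys: a plain Python function can now cross the msgpack/ZMQ boundary to the EngineCore process and come back callable:

from types import FunctionType

import cloudpickle
from msgspec import msgpack

CUSTOM_TYPE_CLOUDPICKLE = 3

def enc_hook(obj):
    # Functions (including lambdas and closures) go through cloudpickle,
    # which can serialize them where plain pickle cannot.
    if isinstance(obj, FunctionType):
        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
    raise NotImplementedError(f"Unsupported type {type(obj)}")

def ext_hook(code, data):
    if code == CUSTOM_TYPE_CLOUDPICKLE:
        return cloudpickle.loads(data)
    raise NotImplementedError(f"Extension type code {code} is not supported")

def double(x):
    return x * 2

payload = msgpack.Encoder(enc_hook=enc_hook).encode(double)   # function -> bytes
restored = msgpack.Decoder(ext_hook=ext_hook).decode(payload)  # bytes -> function
assert restored(21) == 42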
