|
12 | 12 | from concurrent.futures import Future |
13 | 13 | from dataclasses import dataclass, field |
14 | 14 | from threading import Thread |
15 | | -from typing import Any, Callable, Optional, Union |
| 15 | +from typing import Any, Callable, Optional, TypeVar, Union |
16 | 16 |
|
17 | 17 | import zmq |
18 | 18 | import zmq.asyncio |
|
33 | 33 |
|
34 | 34 | AnyFuture = Union[asyncio.Future[Any], Future[Any]] |
35 | 35 |
|
| 36 | +_R = TypeVar('_R') # Return type for collective_rpc |
| 37 | + |
36 | 38 |
|
37 | 39 | class EngineCoreClient(ABC): |
38 | 40 | """ |
@@ -117,6 +119,13 @@ def list_loras(self) -> set[int]: |
117 | 119 | def pin_lora(self, lora_id: int) -> bool: |
118 | 120 | raise NotImplementedError |
119 | 121 |
|
| 122 | + def collective_rpc(self, |
| 123 | + method: Union[str, Callable[..., _R]], |
| 124 | + timeout: Optional[float] = None, |
| 125 | + args: tuple = (), |
| 126 | + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: |
| 127 | + raise NotImplementedError |
| 128 | + |
120 | 129 | async def get_output_async(self) -> EngineCoreOutputs: |
121 | 130 | raise NotImplementedError |
122 | 131 |
|
@@ -153,6 +162,14 @@ async def list_loras_async(self) -> set[int]: |
153 | 162 | async def pin_lora_async(self, lora_id: int) -> bool: |
154 | 163 | raise NotImplementedError |
155 | 164 |
|
| 165 | + async def collective_rpc_async( |
| 166 | + self, |
| 167 | + method: Union[str, Callable[..., _R]], |
| 168 | + timeout: Optional[float] = None, |
| 169 | + args: tuple = (), |
| 170 | + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: |
| 171 | + raise NotImplementedError |
| 172 | + |
156 | 173 |
|
157 | 174 | class InprocClient(EngineCoreClient): |
158 | 175 | """ |
@@ -210,6 +227,13 @@ def list_loras(self) -> set[int]: |
210 | 227 | def pin_lora(self, lora_id: int) -> bool: |
211 | 228 | return self.engine_core.pin_lora(lora_id) |
212 | 229 |
|
| 230 | + def collective_rpc(self, |
| 231 | + method: Union[str, Callable[..., _R]], |
| 232 | + timeout: Optional[float] = None, |
| 233 | + args: tuple = (), |
| 234 | + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: |
| 235 | + return self.engine_core.collective_rpc(method, timeout, args, kwargs) |
| 236 | + |
213 | 237 |
|
214 | 238 | class CoreEngine: |
215 | 239 | """One per data parallel rank.""" |
@@ -505,6 +529,14 @@ def is_sleeping(self) -> bool: |
505 | 529 | def execute_dummy_batch(self) -> None: |
506 | 530 | self.call_utility("execute_dummy_batch") |
507 | 531 |
|
| 532 | + def collective_rpc(self, |
| 533 | + method: Union[str, Callable[..., _R]], |
| 534 | + timeout: Optional[float] = None, |
| 535 | + args: tuple = (), |
| 536 | + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: |
| 537 | + return self.call_utility("collective_rpc", method, timeout, args, |
| 538 | + kwargs) |
| 539 | + |
508 | 540 |
|
509 | 541 | class AsyncMPClient(MPClient): |
510 | 542 | """Asyncio-compatible client for multi-proc EngineCore.""" |
@@ -636,6 +668,15 @@ async def list_loras_async(self) -> set[int]: |
636 | 668 | async def pin_lora_async(self, lora_id: int) -> bool: |
637 | 669 | return await self.call_utility_async("pin_lora", lora_id) |
638 | 670 |
|
| 671 | + async def collective_rpc_async( |
| 672 | + self, |
| 673 | + method: Union[str, Callable[..., _R]], |
| 674 | + timeout: Optional[float] = None, |
| 675 | + args: tuple = (), |
| 676 | + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: |
| 677 | + return await self.call_utility_async("collective_rpc", method, timeout, |
| 678 | + args, kwargs) |
| 679 | + |
639 | 680 |
|
640 | 681 | class DPAsyncMPClient(AsyncMPClient): |
641 | 682 | """Asyncio-compatible client for multi-proc, multi-engine (data parallel) |
|
0 commit comments