# ============================================================================
# tests/v1/disaggregated/__init__.py  (new file, intentionally empty)
# ============================================================================

# ============================================================================
# tests/v1/disaggregated/test_request_manager.py  (new file)
# ============================================================================
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types
from typing import Optional
from unittest.mock import AsyncMock, Mock

import httpx
import pytest

from vllm.distributed.disaggregated.factory import (
    DisaggregatedRequestManagerFactory)
from vllm.distributed.disaggregated.prefill_local_decode_remote_manager import (
    PrefillLocalDecodeRemoteManager)
from vllm.distributed.disaggregated.request_manager import (
    DisaggregatedRequestManager)
from vllm.distributed.disaggregated.server_mixin import (
    DisaggregatedServerMixin)
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request


@pytest.fixture(autouse=True)
def _isolate_factory_registry(monkeypatch):
    """Give every test a fresh (empty) factory registry and restore the
    original registry afterwards, so tests cannot leak registrations."""
    original = DisaggregatedRequestManagerFactory._registry.copy()
    monkeypatch.setattr(DisaggregatedRequestManagerFactory,
                        "_registry", {},
                        raising=False)
    try:
        yield
    finally:
        DisaggregatedRequestManagerFactory._registry = original


def _dummy_config(enabled: bool):
    """Minimal stand-in for VllmConfig: only `kv_transfer_config` is read
    by the code under test (non-None means disaggregation is enabled)."""
    cfg = types.SimpleNamespace()
    cfg.kv_transfer_config = object() if enabled else None
    return cfg


class DummyManagerRemoteOnly(DisaggregatedRequestManager):
    """Test manager that only dispatches after local execution when the
    request opted into remote decode; counts successful dispatches."""
    priority = 1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._call_count = 0

    async def dispatch_request(self, request: Request,
                               local_output: Optional[RequestOutput],
                               client: httpx.AsyncClient,
                               local_executed: bool):
        if local_executed and request.kv_transfer_params.get(
                "do_remote_decode", False):
            self._call_count += 1
            return True, {"ok": True}
        return False, None


def test_register_request_manager():
    # Fresh registry
    assert DisaggregatedRequestManagerFactory._registry == {}

    @DisaggregatedRequestManagerFactory.register("low")
    class _DummyManagerLow(DisaggregatedRequestManager):
        priority = 1

        def dispatch_request(self, *args, **kwargs):
            return False, None

    class _DummyManagerHigh(DisaggregatedRequestManager):
        priority = 0

        def dispatch_request(self, *args, **kwargs):
            return False, None

    # Register the two managers
    DisaggregatedRequestManagerFactory.register_request_manager(
        "high", lambda cfg: _DummyManagerHigh(cfg))

    assert set(DisaggregatedRequestManagerFactory._registry.keys()) == {
        "low", "high"
    }

    # Duplicate registration should fail
    with pytest.raises(ValueError):
        DisaggregatedRequestManagerFactory.register_request_manager(
            "low", lambda cfg: _DummyManagerLow(cfg))

    # Creation should order by ascending priority (high comes before low)
    cfg = _dummy_config(enabled=True)
    managers = DisaggregatedRequestManagerFactory.create_request_managers(cfg)
    assert isinstance(managers[0], _DummyManagerHigh)
    assert isinstance(managers[1], _DummyManagerLow)


@pytest.mark.asyncio
async def test_basic_disaggregated_server_mixin_lifecycle(monkeypatch):
    """
    Basic lifecycle test showcasing how DisaggregatedServerMixin is meant
    to plug into an endpoint.
    """
    DisaggregatedRequestManagerFactory.register_request_manager(
        "dummy", lambda cfg: DummyManagerRemoteOnly(cfg))

    # Disabled path: no managers, hooks return None
    m_disabled = DisaggregatedServerMixin(vllm_config=_dummy_config(
        enabled=False))
    m_disabled.maybe_setup_disaggregated_server()
    assert m_disabled.managers == []

    dummy_req = Request(
        request_id="r1",
        prompt_token_ids=[1, 2],
        sampling_params=SamplingParams(max_tokens=1),
        pooling_params=None,
        eos_token_id=None,
    )
    dummy_req.kv_transfer_params = dict(
        # Additional custom params for OOT plugins go here
        do_remote_decode=True, )
    # Enabled path: setup should call factory and set managers
    m_enabled = DisaggregatedServerMixin(vllm_config=_dummy_config(
        enabled=True))
    m_enabled.maybe_setup_disaggregated_server()
    assert len(m_enabled.managers) == 1
    manager = m_enabled.managers[0]
    assert isinstance(manager, DummyManagerRemoteOnly)

    # 0. Run manager before_local hook: maybe_run_before_local
    res = await m_enabled.maybe_run_disaggregated_before_local(dummy_req)
    assert res is None
    assert manager._call_count == 0
    # 1. Run local (generate..), mock here
    local_output = RequestOutput(request_id="r1",
                                 prompt="",
                                 prompt_token_ids=[1, 2],
                                 prompt_logprobs=None,
                                 outputs=[],
                                 finished=True)
    # 2. Run manager after_local hook: maybe_run_after_local
    res = await m_enabled.maybe_run_disaggregated_after_local(
        dummy_req, local_output)
    assert manager._call_count == 1
    # 3. Return result, manager has priority if remote returns, mock here
    assert res


@pytest.mark.asyncio
async def test_prefill_local_decode_remote_manager():
    manager = PrefillLocalDecodeRemoteManager(_dummy_config(enabled=True))

    def _make_request(request_id: str) -> Request:
        req = Request(
            request_id=request_id,
            prompt_token_ids=[1, 2],
            sampling_params=SamplingParams(max_tokens=16),
            pooling_params=None,
            eos_token_id=None,
        )
        return req

    # Remote decode requests should fall through without modification.
    non_disagg_req = _make_request("r-non-disagg")
    non_disagg_req.kv_transfer_params = {"do_remote_decode": True}
    dispatched, response = await manager.dispatch_request(non_disagg_req,
                                                          None,
                                                          AsyncMock(),
                                                          local_executed=False)
    assert not dispatched
    assert response is None
    assert non_disagg_req.max_tokens == 16

    # Prefill stage should disable streaming and clamp token limits locally.
    prefill_req = _make_request("r-prefill")
    prefill_req.kv_transfer_params = {}
    dispatched, response = await manager.dispatch_request(prefill_req,
                                                          None,
                                                          AsyncMock(),
                                                          local_executed=False)
    assert dispatched is True
    assert response is None
    assert not prefill_req.stream
    assert prefill_req.max_tokens == 1
    assert prefill_req.max_completion_tokens == 1
    assert prefill_req.stream_options is None

    # After local prefill completes, manager should forward to remote decode
    decode_req = _make_request("r-decode")
    decode_req.kv_transfer_params = {}
    remote_params = {"remote_host": "127.0.0.1", "remote_port": 8080}
    local_output = types.SimpleNamespace(kv_transfer_params=remote_params)

    # Mock http client and call
    client = AsyncMock()
    response_payload = {"ok": True}
    dummy_response = Mock()
    dummy_response.raise_for_status = Mock()
    dummy_response.json = Mock(return_value=response_payload)
    client.post.return_value = dummy_response

    dispatched, response = await manager.dispatch_request(decode_req,
                                                          local_output,
                                                          client,
                                                          local_executed=True)

    client.post.assert_awaited_once_with("", json=decode_req)
    dummy_response.raise_for_status.assert_called_once()
    dummy_response.json.assert_called_once()
    assert dispatched
    assert response == response_payload
    assert decode_req.kv_transfer_params == remote_params


# ============================================================================
# vllm/distributed/disaggregated/__init__.py  (new file)
# ============================================================================
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator
from typing import Union

from vllm.entrypoints.openai.protocol import (ChatCompletionResponse,
                                              ErrorResponse)

# TODO generateresponse: widen once a generic GenerationResponse type exists.
# The union covers a streaming SSE generator, a full chat completion, or an
# error payload.
GenerationResponseT = Union[AsyncGenerator[str, None], ChatCompletionResponse,
                            ErrorResponse]


# ============================================================================
# vllm/distributed/disaggregated/decode_local_prefill_remote_manager.py
# ============================================================================
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import httpx

from vllm.config import VllmConfig
from vllm.distributed.disaggregated.request_manager import (
    DisaggregatedRequestManager)
from vllm.outputs import RequestOutput
from vllm.v1.request import Request


class DecodeLocalPrefillRemoteManager(DisaggregatedRequestManager):
    """Runs decode locally while delegating prefill to a remote server.

    Work in progress: remote-prefill forwarding is not implemented yet; the
    manager currently just claims matching requests.
    """
    priority = 1

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        # Cache of per-remote prefill clients, keyed by remote address.
        self.prefill_clients: dict[str, httpx.AsyncClient] = {}

    # Fixed: made async and aligned the signature with the abstract base
    # class (it previously took `shared_http_clients` and was synchronous,
    # so it could never be awaited by DisaggregatedServerMixin).
    async def dispatch_request(self, request: Request,
                               local_output: Optional[RequestOutput],
                               client: httpx.AsyncClient,
                               local_executed: bool):
        kv_params = request.kv_transfer_params
        assert kv_params is not None
        if kv_params.get("do_remote_prefill", False):
            # Non-disaggregated request for this manager: fall through so
            # another manager (or plain local execution) can handle it.
            return False, None

        if not local_executed:
            # Send request to prefill server right away before local Decode.
            # Keep connection open so that a single disconnect allows the
            # whole chain to be closed and cleaned up.
            # TODO: implement remote prefill forwarding.
            return True, None

        # Let local decode run through
        # TODO add the computed token to the prompt
        return True, None


# ============================================================================
# vllm/distributed/disaggregated/factory.py  (new file)
# ============================================================================
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional

from vllm.config import VllmConfig
from vllm.distributed.disaggregated.request_manager import (
    DisaggregatedRequestManager)


class DisaggregatedRequestManagerFactory:
    """Registry + factory for `DisaggregatedRequestManager` implementations.

    Managers are registered by name (explicitly or via the `register`
    decorator) and instantiated in ascending `priority` order.
    """

    # name -> constructor taking a VllmConfig
    _registry: dict[str, Callable[..., DisaggregatedRequestManager]] = {}

    @classmethod
    def register_request_manager(
            cls, class_name: str,
            ctor: Callable[..., DisaggregatedRequestManager]) -> None:
        """Register a request manager along with its constructor.

        Raises:
            ValueError: if `class_name` is already registered.
        """
        if class_name in cls._registry:
            raise ValueError(
                f"Request manager '{class_name}' is already registered.")

        cls._registry[class_name] = ctor

    @classmethod
    def create_request_managers(
        cls,
        config: "VllmConfig",
    ) -> list[DisaggregatedRequestManager]:
        """Instantiate all registered managers, ordered by ascending
        priority (lower number = dispatched first)."""
        managers = []
        for manager_ctor in cls._registry.values():
            manager = manager_ctor(config)
            managers.append((manager.priority, manager))
        return [manager for _, manager in sorted(managers, key=lambda x: x[0])]

    @classmethod
    def register(cls, name: Optional[str] = None):
        """Class decorator to register a `DisaggregatedRequestManager`.

        Usage:
            @DisaggregatedRequestManagerFactory.register("MyManager")
            class MyManager(DisaggregatedRequestManager):
                ...

        If `name` is None, the class' __name__ is used.
        """

        def _decorator(manager_cls: type[DisaggregatedRequestManager]):
            register_name = name or manager_cls.__name__

            def _ctor(config: VllmConfig) -> DisaggregatedRequestManager:
                return manager_cls(config)

            cls.register_request_manager(register_name, _ctor)
            return manager_cls

        return _decorator


# ============================================================================
# vllm/distributed/disaggregated/prefill_local_decode_remote_manager.py
# ============================================================================
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import httpx

from vllm.config import VllmConfig
from vllm.distributed.disaggregated.request_manager import (
    DisaggregatedRequestManager)
from vllm.outputs import RequestOutput
from vllm.v1.request import Request


class PrefillLocalDecodeRemoteManager(DisaggregatedRequestManager):
    """Runs prefill locally, then forwards the request (with the KV-transfer
    params produced by the local prefill) to a remote decode server."""
    priority = 1

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        # Cache of per-remote prefill clients, keyed by remote address.
        self.prefill_clients: dict[str, httpx.AsyncClient] = {}

    async def dispatch_request(self, request: Request,
                               local_output: Optional[RequestOutput],
                               client: httpx.AsyncClient,
                               local_executed: bool):
        kv_params = request.kv_transfer_params
        assert kv_params is not None
        if kv_params.get("do_remote_decode", False):
            # Non-disaggregated request for this manager: fall through.
            return False, None

        if local_executed:
            # Local Prefill completed: open a new connection to the remote
            # decode server and stream response back to the client
            assert local_output.kv_transfer_params is not None
            # Contains P->D transfer params
            request.kv_transfer_params = local_output.kv_transfer_params

            # TODO deferred decode selection here
            host = request.kv_transfer_params["remote_host"]
            port = request.kv_transfer_params["remote_port"]
            assert host and port

            # TODO return streaming generator
            # NOTE(review): posting `request` as `json=` assumes Request is
            # JSON-serializable (mocked in tests) — confirm before shipping.
            response = await client.post("", json=request)
            response.raise_for_status()
            return True, response.json()

        # Let local prefill run through, but make sure streaming is disabled
        # TODO save to figure out if streaming is needed
        request.stream = False
        request.max_tokens = 1
        request.max_completion_tokens = 1
        request.stream_options = None

        return True, None


# ============================================================================
# vllm/distributed/disaggregated/request_manager.py  (new file)
# ============================================================================
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

import httpx

from vllm.distributed.disaggregated import GenerationResponseT
from vllm.outputs import RequestOutput
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class DisaggregatedRequestManager(ABC):
    """Base class for disaggregated (e.g. prefill/decode split) request
    dispatchers. Managers are tried in ascending `priority` order."""

    # Lower value = consulted earlier by DisaggregatedServerMixin.
    priority: int = 0

    def __init__(self, vllm_config: "VllmConfig"):
        self._vllm_config = vllm_config

    @abstractmethod
    async def dispatch_request(
        self,
        request: Request,
        local_output: Optional[RequestOutput],
        client: httpx.AsyncClient,
        local_executed: bool = False
    ) -> tuple[bool, Optional[GenerationResponseT]]:
        """Attempt to dispatch `request`.

        Returns:
            (dispatched, response): the first flag indicates whether the
            request was successfully dispatched by this manager; the second
            element is an optional remote response to return to the caller.
        """


# ============================================================================
# vllm/distributed/disaggregated/server_mixin.py  (new file)
# ============================================================================
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from typing import Optional

import httpx

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.disaggregated import GenerationResponseT
from vllm.distributed.disaggregated.factory import (
    DisaggregatedRequestManagerFactory)
from vllm.distributed.disaggregated.request_manager import (
    DisaggregatedRequestManager)
from vllm.outputs import RequestOutput
from vllm.v1.request import Request


class DisaggregatedServerMixin:
    """Endpoint mixin that wires registered `DisaggregatedRequestManager`s
    into the request lifecycle via before/after-local hooks."""

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        # Shared client reused across dispatches; closed in `shutdown`.
        self.client = httpx.AsyncClient()
        # TODO Logic for enabling might be more complex and platform-dependent
        self._enabled = self.vllm_config.kv_transfer_config is not None
        self.managers = list[DisaggregatedRequestManager]()
        # TODO allow different API keys for different remotes
        self._api_key = envs.VLLM_API_KEY

    def maybe_setup_disaggregated_server(self):
        """Instantiate managers (priority-ordered) when disaggregation is
        enabled; no-op otherwise."""
        if not self._enabled:
            return

        # Initialize managers, ordered by priority for dispatching
        self.managers = DisaggregatedRequestManagerFactory.\
            create_request_managers(self.vllm_config)

    # Fixed: previously annotated as returning AbstractContextManager and
    # returned `nullcontext()` on the disabled path while returning a plain
    # response otherwise — an inconsistent contract. Now uniformly returns
    # Optional[GenerationResponseT] (None when nothing was dispatched).
    async def _maybe_run_disaggregated(
        self,
        request: Request,
        local_output: Optional[RequestOutput],
        local_executed: bool,
    ) -> Optional[GenerationResponseT]:
        if not self._enabled or request.kv_transfer_params is None:
            return None

        # Setup shared client
        if kv_params := request.kv_transfer_params:
            host = kv_params.get("remote_host")
            port = kv_params.get("remote_port")
            if host and port:
                # TODO set other endpoints
                self.client.base_url = (
                    f"http://{host}:{port}/v1/chat/completions")
                self.client.headers.update({
                    "Authorization": f"Bearer {self._api_key}",
                    "X-Request-Id": request.request_id
                })

        # Dispatch to Manager with highest priority (first success wins)
        response = None
        for manager in self.managers:
            success, response = await manager.dispatch_request(
                request,
                local_output,
                self.client,
                local_executed=local_executed)
            if success:
                break

        return response

    async def maybe_run_disaggregated_before_local(
        self,
        request: Request,
    ) -> Optional[GenerationResponseT]:
        """Hook to run before local execution (no local output yet)."""
        return await self._maybe_run_disaggregated(request, None, False)

    async def maybe_run_disaggregated_after_local(
        self,
        request: Request,
        local_output: RequestOutput,
    ) -> Optional[GenerationResponseT]:
        """Hook to run after local execution completed with `local_output`."""
        return await self._maybe_run_disaggregated(request, local_output, True)

    def shutdown(self):
        """Close the shared HTTP client.

        Fixed: `aclose()` is a coroutine and was previously called without
        being awaited, so the client was never actually closed. Schedule it
        on the running loop, or drive it with `asyncio.run` when called
        outside an event loop.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(self.client.aclose())
        else:
            loop.create_task(self.client.aclose())