diff --git a/tests/entrypoints/openai_api/test_openpi_connection.py b/tests/entrypoints/openai_api/test_openpi_connection.py new file mode 100644 index 00000000000..ba5a45c3454 --- /dev/null +++ b/tests/entrypoints/openai_api/test_openpi_connection.py @@ -0,0 +1,270 @@ +import asyncio +import builtins +import sys +import types +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from vllm_omni.entrypoints.openpi import connection as openpi_connection +from vllm_omni.entrypoints.openpi.serving import PolicyServerConfig + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class FakeWebSocket: + def __init__(self, messages): + self._messages = list(messages) + self.sent_bytes = [] + self.sent_texts = [] + self.accepted = False + self.closed = False + + async def accept(self): + self.accepted = True + + async def send_bytes(self, data): + self.sent_bytes.append(data) + + async def send_text(self, data): + self.sent_texts.append(data) + + async def receive(self): + return self._messages.pop(0) + + async def close(self): + self.closed = True + + +def _serving_mock(): + serving = MagicMock() + serving.policy_server_config = PolicyServerConfig( + { + "image_resolution": (180, 320), + "n_external_cameras": 2, + "needs_wrist_camera": True, + "needs_stereo_camera": False, + "needs_session_id": True, + "action_space": "joint_position", + } + ) + serving.infer = AsyncMock(return_value=[0.0]) + return serving + + +def test_pack_reports_clear_error_when_openpi_client_is_missing(monkeypatch): + real_import = builtins.__import__ + + def import_without_openpi_client(name, globals=None, locals=None, fromlist=(), level=0): + if name == "openpi_client": + raise ModuleNotFoundError("No module named 'openpi_client'", name="openpi_client") + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", import_without_openpi_client) + + with pytest.raises(ImportError) as exc_info: + openpi_connection._pack({"prompt": "pick up the object"}) + + message = str(exc_info.value) + assert "/v1/realtime/robot/openpi" in message + assert "pip install openpi-client" in message + + +def test_pack_and_unpack_delegate_to_openpi_msgpack_numpy(monkeypatch): + calls = [] + + class FakeMsgpackNumpy: + @staticmethod + def packb(obj): + calls.append(("packb", obj)) + return b"packed" + + @staticmethod + def unpackb(data): + calls.append(("unpackb", data)) + return {"unpacked": data} + + fake_openpi_client = types.ModuleType("openpi_client") + fake_openpi_client.msgpack_numpy = FakeMsgpackNumpy + monkeypatch.setitem(sys.modules, "openpi_client", fake_openpi_client) + + assert openpi_connection._pack({"x": 1}) == b"packed" + assert openpi_connection._unpack(b"payload") == {"unpacked": b"payload"} + assert calls == [ + ("packb", {"x": 1}), + ("unpackb", b"payload"), + ] + + +def test_handle_connection_returns_structured_error_for_invalid_payload(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + monkeypatch.setattr( + openpi_connection, + "_unpack", + lambda _data: (_ for _ in ()).throw(ValueError("bad payload traceback")), + ) + + websocket = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"bad"}, + {"type": "websocket.disconnect"}, + ] + ) + serving = MagicMock() + connection = openpi_connection.RobotRealtimeConnection(websocket, serving) + + asyncio.run(connection.handle_connection()) + + assert websocket.accepted is True + assert websocket.sent_bytes[1] == {"type": "error", "message": "Invalid request payload"} + assert "traceback" not in str(websocket.sent_bytes[1]).lower() + assert websocket.sent_texts == [] + serving.infer.assert_not_called() + serving.reset.assert_not_called() + + +def test_handle_connection_rejects_oversized_payload_before_unpack(monkeypatch): + unpack_mock = MagicMock(side_effect=AssertionError("_unpack should not be called")) + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + monkeypatch.setattr(openpi_connection, "_unpack", unpack_mock) + monkeypatch.setattr(openpi_connection, "MAX_OPENPI_PAYLOAD_BYTES", 4) + + websocket = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"too-large"}, + {"type": "websocket.disconnect"}, + ] + ) + serving = MagicMock() + connection = openpi_connection.RobotRealtimeConnection(websocket, serving) + + asyncio.run(connection.handle_connection()) + + assert websocket.sent_bytes[1] == {"type": "error", "message": "Invalid request payload"} + unpack_mock.assert_not_called() + serving.infer.assert_not_called() + serving.reset.assert_not_called() + + +def test_handle_connection_returns_structured_error_for_infer_exception(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + monkeypatch.setattr( + openpi_connection, + "_unpack", + lambda _data: {"prompt": "pick up the object"}, + ) + + websocket = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"request"}, + {"type": "websocket.disconnect"}, + ] + ) + serving = MagicMock() + serving.infer = AsyncMock(side_effect=RuntimeError("secret traceback text")) + connection = openpi_connection.RobotRealtimeConnection(websocket, serving) + + asyncio.run(connection.handle_connection()) + + assert websocket.sent_bytes[1] == {"type": "error", "message": "Internal inference error"} + assert "secret traceback text" not in str(websocket.sent_bytes[1]) + assert websocket.sent_texts == [] + serving.infer.assert_awaited_once_with( + {"prompt": "pick up the object"}, + session_id="default", + reset=True, + ) + + +def test_handle_connection_closes_websocket_on_idle_timeout(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + + websocket = FakeWebSocket([]) + + async def never_receives(): + await asyncio.sleep(1) + + websocket.receive = never_receives + serving = MagicMock() + serving.policy_server_config = PolicyServerConfig( + { + "image_resolution": (180, 320), + "n_external_cameras": 2, + "needs_wrist_camera": True, + "needs_stereo_camera": False, + "needs_session_id": True, + "action_space": "joint_position", + } + ) + connection = openpi_connection.RobotRealtimeConnection( + websocket, + serving, + idle_timeout=0.01, + ) + + asyncio.run(connection.handle_connection()) + + assert websocket.accepted is True + assert websocket.sent_bytes[0]["action_space"] == "joint_position" + assert websocket.closed is True + assert websocket.sent_texts == [] + serving.infer.assert_not_called() + + +def test_handle_connection_keeps_session_state_per_websocket(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + requests = { + b"a1": {"prompt": "first", "session_id": "session-a"}, + b"a2": {"prompt": "second", "session_id": "session-a"}, + b"b1": {"prompt": "other", "session_id": "session-b"}, + } + monkeypatch.setattr(openpi_connection, "_unpack", lambda data: dict(requests[data])) + serving = _serving_mock() + + websocket_a = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"a1"}, + {"type": "websocket.receive", "bytes": b"a2"}, + {"type": "websocket.disconnect"}, + ] + ) + websocket_b = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"b1"}, + {"type": "websocket.disconnect"}, + ] + ) + + asyncio.run(openpi_connection.RobotRealtimeConnection(websocket_a, serving).handle_connection()) + asyncio.run(openpi_connection.RobotRealtimeConnection(websocket_b, serving).handle_connection()) + + calls = serving.infer.await_args_list + assert calls[0].kwargs == {"session_id": "session-a", "reset": True} + assert calls[1].kwargs == {"session_id": "session-a", "reset": False} + assert calls[2].kwargs == {"session_id": "session-b", "reset": True} + + +def test_handle_connection_reset_endpoint_resets_next_infer(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + requests = { + b"a1": {"prompt": "first", "session_id": "session-a"}, + b"reset": {"endpoint": "reset"}, + b"a2": {"prompt": "second", "session_id": "session-a"}, + } + monkeypatch.setattr(openpi_connection, "_unpack", lambda data: dict(requests[data])) + serving = _serving_mock() + websocket = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"a1"}, + {"type": "websocket.receive", "bytes": b"reset"}, + {"type": "websocket.receive", "bytes": b"a2"}, + {"type": "websocket.disconnect"}, + ] + ) + + asyncio.run(openpi_connection.RobotRealtimeConnection(websocket, serving).handle_connection()) + + assert [call.kwargs["reset"] for call in serving.infer.await_args_list] == [True, True] + serving.reset.assert_called_once_with({}) + assert websocket.sent_bytes[2] == {"status": "reset successful"} + assert websocket.sent_texts == [] diff --git a/tests/entrypoints/openai_api/test_openpi_serving.py b/tests/entrypoints/openai_api/test_openpi_serving.py new file mode 100644 index 00000000000..9eb1d5bfe0b --- /dev/null +++ b/tests/entrypoints/openai_api/test_openpi_serving.py @@ -0,0 +1,338 @@ +import asyncio +import json +import threading +from concurrent.futures import ThreadPoolExecutor +from types import SimpleNamespace + +import numpy as np +import pytest +from fastapi import FastAPI, WebSocket +from omegaconf import OmegaConf +from starlette.testclient import TestClient + +from vllm_omni.entrypoints.openpi import connection as openpi_connection +from vllm_omni.entrypoints.openpi import serving as openpi_serving + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +TEST_POLICY_SERVER_CONFIG = { + "image_resolution": (180, 320), + "n_external_cameras": 2, + "needs_wrist_camera": True, + "needs_stereo_camera": False, + "needs_session_id": True, + "action_space": "joint_position", +} + + +def _json_default(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +def _json_pack(obj): + return json.dumps(obj, default=_json_default).encode() + + +def _json_unpack(data): + return json.loads(data.decode()) + + +def _engine_with_policy_config(policy_config=None): + od_config = SimpleNamespace(model_config={"policy_server_config": policy_config or TEST_POLICY_SERVER_CONFIG}) + return SimpleNamespace(get_diffusion_od_config=lambda: od_config) + + +class RecordingEngine: + def __init__(self): + self.od_config = SimpleNamespace(model_config={"policy_server_config": TEST_POLICY_SERVER_CONFIG}) + self.generate_calls = [] + + def get_diffusion_od_config(self): + return self.od_config + + def generate(self, *, prompt, request_id, sampling_params_list): + async def _generate(): + self.generate_calls.append( + { + "prompt": prompt, + "request_id": request_id, + "sampling_params_list": sampling_params_list, + } + ) + yield SimpleNamespace(multimodal_output={"actions": [0.0]}) + + return _generate() + + +class ConcurrentRecordingEngine(RecordingEngine): + def __init__(self, *, expected_calls: int): + super().__init__() + self.expected_calls = expected_calls + self.condition = threading.Condition() + self.saw_overlap = False + + def _wait_for_expected_calls(self): + with self.condition: + completed = self.condition.wait_for( + lambda: len(self.generate_calls) >= self.expected_calls, + timeout=5.0, + ) + self.saw_overlap = self.saw_overlap or completed + + def generate(self, *, prompt, request_id, sampling_params_list): + async def _generate(): + with self.condition: + self.generate_calls.append( + { + "prompt": prompt, + "request_id": request_id, + "sampling_params_list": sampling_params_list, + } + ) + if len(self.generate_calls) >= self.expected_calls: + self.saw_overlap = True + self.condition.notify_all() + + await asyncio.to_thread(self._wait_for_expected_calls) + yield SimpleNamespace(multimodal_output={"actions": [0.0]}) + + return _generate() + + +def test_policy_server_config_reads_diffusion_model_config(): + policy_config = { + "image_resolution": [64, 64], + "n_external_cameras": 1, + "custom_model_key": {"nested": True}, + } + od_config = SimpleNamespace(model_config={"policy_server_config": policy_config}) + engine_client = SimpleNamespace(get_diffusion_od_config=lambda: od_config) + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert serving.policy_server_config.to_dict() == policy_config + + +def test_policy_server_config_reads_stage_config_model_config(): + policy_config = {"custom_model_key": "from-stage-config"} + engine_client = SimpleNamespace( + get_diffusion_od_config=lambda: None, + stage_configs=[ + SimpleNamespace( + stage_type="diffusion", + engine_args=SimpleNamespace(model_config={"policy_server_config": policy_config}), + ) + ], + ) + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert serving.policy_server_config.to_dict() == policy_config + + +def test_policy_server_config_reads_omegaconf_stage_config(): + engine_client = SimpleNamespace( + get_diffusion_od_config=lambda: None, + stage_configs=[ + SimpleNamespace( + stage_type="diffusion", + engine_args=SimpleNamespace( + model_config=OmegaConf.create({"policy_server_config": {"custom_model_key": "from-omegaconf"}}) + ), + ) + ], + ) + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert serving.policy_server_config.to_dict() == {"custom_model_key": "from-omegaconf"} + + +def test_policy_server_config_is_required(): + od_config = SimpleNamespace(model_config={}) + engine_client = SimpleNamespace(get_diffusion_od_config=lambda: od_config) + + with pytest.raises(ValueError) as exc_info: + openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert "policy_server_config" in str(exc_info.value) + + +def test_create_policy_server_returns_none_without_policy_config(): + od_config = SimpleNamespace(model_config={}) + engine_client = SimpleNamespace(get_diffusion_od_config=lambda: od_config) + + serving = openpi_serving.ServingRealtimeRobotOpenPI.create_policy_server( + engine_client=engine_client, + model_name="generic-model", + ) + + assert serving is None + + +def test_policy_server_config_reads_engine_model_config(): + policy_config = {"custom_model_key": "custom-value"} + engine_client = SimpleNamespace(model_config=SimpleNamespace(policy_server_config=policy_config)) + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert serving.policy_server_config.to_dict() == policy_config + + +def test_build_request_uses_unique_engine_request_id_per_inference(): + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=_engine_with_policy_config()) + + request_a = serving._build_request( + {"prompt": "pick up the object"}, + session_id="session-a", + reset=True, + ) + request_b = serving._build_request( + {"prompt": "pick up the object"}, + session_id="session-a", + reset=False, + ) + + assert request_a.sampling_params.extra_args["reset"] is True + assert request_b.sampling_params.extra_args["reset"] is False + assert request_a.sampling_params.extra_args["session_id"] == "session-a" + assert request_b.sampling_params.extra_args["session_id"] == "session-a" + assert request_a.sampling_params.extra_args["robot_obs"]["prompt"] == "pick up the object" + assert request_b.sampling_params.extra_args["robot_obs"]["prompt"] == "pick up the object" + + assert request_a.request_id == "robot-session-a-0" + assert request_b.request_id == "robot-session-a-1" + assert request_a.request_id != request_b.request_id + + +def test_infer_keeps_session_state_but_uses_unique_engine_request_ids(): + engine = RecordingEngine() + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine) + + async def run_requests(): + await serving.infer({"prompt": "pick up the object"}, session_id="session-a", reset=True) + await serving.infer({"prompt": "pick up the object"}, session_id="session-a", reset=False) + + asyncio.run(run_requests()) + + assert [call["request_id"] for call in engine.generate_calls] == [ + "robot-session-a-0", + "robot-session-a-1", + ] + assert engine.generate_calls[0]["request_id"] != engine.generate_calls[1]["request_id"] + + sampling_params_a = engine.generate_calls[0]["sampling_params_list"][0] + sampling_params_b = engine.generate_calls[1]["sampling_params_list"][0] + assert sampling_params_a.extra_args["session_id"] == "session-a" + assert sampling_params_b.extra_args["session_id"] == "session-a" + assert sampling_params_a.extra_args["reset"] is True + assert sampling_params_b.extra_args["reset"] is False + + +def test_two_websocket_clients_without_session_id_do_not_conflict(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", _json_pack) + monkeypatch.setattr(openpi_connection, "_unpack", _json_unpack) + + engine = ConcurrentRecordingEngine(expected_calls=2) + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine) + app = FastAPI() + + @app.websocket("/v1/realtime/robot/openpi") + async def openpi_endpoint(websocket: WebSocket): + connection = openpi_connection.RobotRealtimeConnection(websocket, serving) + await connection.handle_connection() + + def run_client(prompt: str): + with TestClient(app) as client: + with client.websocket_connect("/v1/realtime/robot/openpi") as websocket: + metadata = _json_unpack(websocket.receive_bytes()) + assert metadata["needs_session_id"] is True + + websocket.send_bytes(_json_pack({"prompt": prompt})) + actions = _json_unpack(websocket.receive_bytes()) + np.testing.assert_array_equal( + np.asarray(actions, dtype=np.float32), + np.asarray([0.0], dtype=np.float32), + ) + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [ + executor.submit(run_client, "first client"), + executor.submit(run_client, "second client"), + ] + for future in futures: + future.result(timeout=10.0) + + request_ids = [call["request_id"] for call in engine.generate_calls] + assert len(request_ids) == 2 + assert len(set(request_ids)) == 2 + assert all(request_id.startswith("robot-default-") for request_id in request_ids) + assert engine.saw_overlap is True + + sampling_params = [call["sampling_params_list"][0] for call in engine.generate_calls] + assert [params.extra_args["session_id"] for params in sampling_params] == ["default", "default"] + assert [params.extra_args["reset"] for params in sampling_params] == [True, True] + + +def test_infer_extracts_actions_from_generic_multimodal_output(): + class FakeEngineClient: + def get_diffusion_od_config(self): + return SimpleNamespace(model_config={"policy_server_config": TEST_POLICY_SERVER_CONFIG}) + + async def generate(self, **kwargs): + self.generate_kwargs = kwargs + yield SimpleNamespace(multimodal_output={"actions": [[1.0, 2.0, 3.0]]}) + + engine_client = FakeEngineClient() + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + actions = asyncio.run(serving.infer({"prompt": "pick up"}, session_id="session-a", reset=True)) + + np.testing.assert_allclose(actions, np.array([[1.0, 2.0, 3.0]], dtype=np.float32)) + assert engine_client.generate_kwargs["prompt"] == "pick up" + assert engine_client.generate_kwargs["request_id"] == "robot-session-a-0" + + +def test_infer_preserves_dict_actions_from_multimodal_output(): + class FakeEngineClient: + def get_diffusion_od_config(self): + return SimpleNamespace(model_config={"policy_server_config": TEST_POLICY_SERVER_CONFIG}) + + async def generate(self, **kwargs): + self.generate_kwargs = kwargs + yield SimpleNamespace( + multimodal_output={ + "actions": { + "left_arm": [[1.0, 2.0]], + "right_arm": np.array([[3.0, 4.0]], dtype=np.float64), + } + } + ) + + engine_client = FakeEngineClient() + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + actions = asyncio.run(serving.infer({"prompt": "pick up"}, session_id="session-a", reset=True)) + + assert isinstance(actions, dict) + assert set(actions) == {"left_arm", "right_arm"} + np.testing.assert_allclose(actions["left_arm"], np.array([[1.0, 2.0]], dtype=np.float32)) + np.testing.assert_allclose(actions["right_arm"], np.array([[3.0, 4.0]], dtype=np.float32)) + assert actions["left_arm"].dtype == np.float32 + assert actions["right_arm"].dtype == np.float32 + + +def test_extract_actions_does_not_iterate_result_object(): + class IterableResult: + multimodal_output = {"actions": [[1.0, 2.0, 3.0]]} + + def __iter__(self): + raise AssertionError("result object should not be iterated") + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=_engine_with_policy_config()) + + actions = serving._extract_actions(IterableResult()) + + np.testing.assert_allclose(actions, np.array([[1.0, 2.0, 3.0]], dtype=np.float32)) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 4db31efbbd3..d37a2374ab5 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -133,6 +133,7 @@ from vllm_omni.entrypoints.openai.stores import VIDEO_STORE, VIDEO_TASKS from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request from vllm_omni.entrypoints.openai.video_api_utils import decode_input_reference +from vllm_omni.entrypoints.openpi.serving import ServingRealtimeRobotOpenPI from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt logger = init_logger(__name__) @@ -633,6 +634,10 @@ async def omni_init_app_state( ) state.openai_streaming_speech = None state.openai_streaming_video = None + state.openai_serving_realtime_robot = ServingRealtimeRobotOpenPI.create_policy_server( + engine_client=engine_client, + model_name=model_name, + ) state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) state.server_load_metrics = 0 @@ -963,6 +968,7 @@ async def omni_init_app_state( model_name=served_model_names[0] if served_model_names else None, stage_configs=state.stage_configs, ) + state.openai_serving_realtime_robot = None state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -1406,6 +1412,23 @@ async def realtime_websocket(websocket: WebSocket): await connection.handle_connection() +@router.websocket("/v1/realtime/robot/openpi") +async def realtime_robot_openpi(websocket: WebSocket): + """WebSocket endpoint for robot policy inference via OpenPI messages.""" + from vllm_omni.entrypoints.openpi.connection import ( + RobotRealtimeConnection, + ) + + serving = getattr(websocket.app.state, "openai_serving_realtime_robot", None) + if serving is None: + await websocket.accept() + await websocket.send_json({"type": "error", "error": "Robot policy not available", "code": "unsupported"}) + await websocket.close() + return + connection = RobotRealtimeConnection(websocket, serving) + await connection.handle_connection() + + # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openpi/__init__.py b/vllm_omni/entrypoints/openpi/__init__.py new file mode 100644 index 00000000000..9881313609a --- /dev/null +++ b/vllm_omni/entrypoints/openpi/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/vllm_omni/entrypoints/openpi/connection.py b/vllm_omni/entrypoints/openpi/connection.py new file mode 100644 index 00000000000..670c37a9233 --- /dev/null +++ b/vllm_omni/entrypoints/openpi/connection.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""WebSocket connection for robot policy inference (OpenPI protocol). + +Protocol (compatible with OpenPI policy clients): + Connect -> server sends msgpack(PolicyServerConfig fields) + Infer -> client sends msgpack(obs), server sends msgpack(ndarray) + Reset -> client sends msgpack({endpoint:reset}), server sends msgpack(status) +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from fastapi import WebSocket +from starlette.websockets import WebSocketDisconnect +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openpi.serving import ( + ServingRealtimeRobotOpenPI, +) + +logger = init_logger(__name__) +_DEFAULT_IDLE_TIMEOUT = 30.0 +MAX_OPENPI_PAYLOAD_BYTES = 64 * 1024 * 1024 + + +def _get_msgpack_numpy() -> Any: + try: + from openpi_client import msgpack_numpy + except ImportError as exc: + raise ImportError( + "The `/v1/realtime/robot/openpi` endpoint requires the optional " + "`openpi-client` dependency. Install it with `pip install openpi-client`." + ) from exc + + return msgpack_numpy + + +def _pack(obj: Any) -> bytes: + return _get_msgpack_numpy().packb(obj) + + +def _unpack(data: bytes) -> Any: + return _get_msgpack_numpy().unpackb(data) + + +class RobotRealtimeConnection: + """WebSocket connection for robot policy inference.""" + + def __init__( + self, + websocket: WebSocket, + serving: ServingRealtimeRobotOpenPI, + idle_timeout: float = _DEFAULT_IDLE_TIMEOUT, + ) -> None: + self.websocket = websocket + self.serving = serving + self._idle_timeout = idle_timeout + self._current_session_id: str | None = None + self._call_count = 0 + + def reset(self) -> None: + self._current_session_id = None + self._call_count = 0 + + async def _send_error(self, message: str) -> None: + await self.websocket.send_bytes(_pack({"type": "error", "message": message})) + + def _unpack_request(self, data: bytes) -> dict[str, Any]: + if len(data) > MAX_OPENPI_PAYLOAD_BYTES: + raise ValueError("OpenPI request payload too large") + obs = _unpack(data) + if not isinstance(obs, dict): + raise ValueError("Invalid request payload") + return obs + + async def handle_connection(self) -> None: + """Main loop for OpenPI-compatible policy serving.""" + await self.websocket.accept() + + try: + # Send model-specific PolicyServerConfig resolved by serving from + # diffusion od_config.model_config. + metadata = self.serving.policy_server_config.to_dict() + await self.websocket.send_bytes(_pack(metadata)) + + while True: + try: + msg = await asyncio.wait_for( + self.websocket.receive(), + timeout=self._idle_timeout, + ) + except asyncio.TimeoutError: + logger.info("Robot OpenPI connection idle timeout after %.1f seconds", self._idle_timeout) + try: + await self.websocket.close() + except Exception: + logger.debug("Failed to close idle robot OpenPI websocket", exc_info=True) + return + + if msg.get("type") == "websocket.disconnect": + break + + if "bytes" not in msg or not msg["bytes"]: + continue + + try: + obs = self._unpack_request(msg["bytes"]) + except Exception: + logger.exception("Invalid robot OpenPI request payload") + try: + await self._send_error("Invalid request payload") + except Exception: + break + continue + + try: + endpoint = obs.pop("endpoint", "infer") + + if endpoint == "reset": + self.reset() + self.serving.reset(obs) + await self.websocket.send_bytes(_pack({"status": "reset successful"})) + else: + session_id = str(obs.get("session_id") or self._current_session_id or "default") + if session_id != self._current_session_id: + if self._current_session_id is not None: + logger.info( + "Robot OpenPI session changed %s -> %s", + self._current_session_id, + session_id, + ) + self._current_session_id = session_id + self._call_count = 0 + + self._call_count += 1 + actions = await self.serving.infer( + obs, + session_id=session_id, + reset=self._call_count <= 1, + ) + await self.websocket.send_bytes(_pack(actions)) + except Exception: + logger.exception("Error handling request") + try: + await self._send_error("Internal inference error") + except Exception: + break + + except WebSocketDisconnect: + pass + except Exception: + logger.exception("Connection error") diff --git a/vllm_omni/entrypoints/openpi/serving.py b/vllm_omni/entrypoints/openpi/serving.py new file mode 100644 index 00000000000..ec5c88408ea --- /dev/null +++ b/vllm_omni/entrypoints/openpi/serving.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Serving layer for robot policy inference via `/v1/realtime/robot/openpi`. + +Flow: raw obs → engine request → actions. +The loaded policy model owns dataset transforms inside its pipeline. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from itertools import count +from typing import Any + +import numpy as np +from omegaconf import OmegaConf +from vllm.logger import init_logger + +logger = init_logger(__name__) + +ActionOutput = np.ndarray | dict[str, np.ndarray] + + +def _to_builtin_container(value: Any) -> Any: + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=True) + if isinstance(value, Mapping): + return {key: _to_builtin_container(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_to_builtin_container(item) for item in value] + return value + + +@dataclass(frozen=True) +class PolicyServerConfig: + """OpenPI policy server handshake config. + + Values are model-specific and must be provided by the loaded policy model. + """ + + values: dict[str, Any] + + @classmethod + def from_model_config(cls, model_config: Any) -> PolicyServerConfig: + if isinstance(model_config, Mapping): + raw_config = model_config.get("policy_server_config") + else: + raw_config = getattr(model_config, "policy_server_config", None) + + if raw_config is None: + raise ValueError("Robot OpenPI serving requires policy_server_config.") + if isinstance(raw_config, cls): + return raw_config + if not isinstance(raw_config, Mapping): + raise ValueError("Robot OpenPI serving requires policy_server_config.") + return cls(_to_builtin_container(raw_config)) + + def to_dict(self) -> dict[str, Any]: + return _to_builtin_container(self.values) + + +class ServingRealtimeRobotOpenPI: + """Robot policy serving layer for OpenPI protocol. + + Model-specific transform/state lives in the diffusion pipeline. + """ + + def __init__( + self, + engine_client: Any, + model_name: str | None = None, + ) -> None: + self.engine_client = engine_client + self.model_name = model_name + self.policy_server_config = self._get_policy_server_config(engine_client) + self._request_counter = count() + + @classmethod + def create_policy_server( + cls, + engine_client: Any, + model_name: str | None = None, + ) -> ServingRealtimeRobotOpenPI | None: + try: + return cls(engine_client=engine_client, model_name=model_name) + except ValueError as exc: + if "policy_server_config" not in str(exc): + raise + logger.info("Robot OpenPI serving disabled for model %s", model_name) + return None + + @staticmethod + def _get_policy_server_config(engine_client: Any) -> PolicyServerConfig: + model_config = None + get_od_config = getattr(engine_client, "get_diffusion_od_config", None) + if callable(get_od_config): + od_config = get_od_config() + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + for stage_config in getattr(engine_client, "stage_configs", []) or []: + if getattr(stage_config, "stage_type", None) != "diffusion": + continue + engine_args = getattr(stage_config, "engine_args", None) + model_config = getattr(engine_args, "model_config", None) + if model_config is not None: + break + + if model_config is None: + od_config = getattr(engine_client, "od_config", None) + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + model_config = getattr(engine_client, "model_config", None) + return PolicyServerConfig.from_model_config(model_config) + + def reset(self, obs: dict) -> None: + """Compatibility hook; per-connection state lives in RobotRealtimeConnection.""" + + async def infer(self, obs: dict, *, session_id: str, reset: bool) -> ActionOutput: + """raw obs → engine → actions.""" + # Build request, run inference through AsyncOmni + request = self._build_request(obs, session_id=session_id, reset=reset) + result = None + # OpenPI policy serving is one request -> one action reply. AsyncOmni + # exposes an async iterator, so consume it to completion and use the + # final output, matching other non-streaming OpenAI serving paths. + async for output in self.engine_client.generate( + prompt=request.prompts[0], + request_id=request.request_id, + sampling_params_list=[request.sampling_params], + ): + result = output + if result is None: + raise RuntimeError("Robot OpenPI request produced no output.") + + return self._extract_actions(result) + + def _next_request_id(self, session_id: str) -> str: + return f"robot-{session_id}-{next(self._request_counter)}" + + def _build_request(self, obs: dict, *, session_id: str, reset: bool) -> Any: + """Build engine request from raw robot obs. + + Returns an `OmniDiffusionRequest` payload consumed by + `AsyncOmni.generate()` and routed to the diffusion stage. + """ + from vllm_omni.diffusion.request import OmniDiffusionRequest + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + extra_args = { + "reset": reset, + "session_id": session_id, + "robot_obs": obs, + } + + prompt = obs.get("prompt", "") + sampling_params = OmniDiffusionSamplingParams(extra_args=extra_args) + return OmniDiffusionRequest( + prompts=[prompt], + sampling_params=sampling_params, + request_id=self._next_request_id(session_id), + ) + + def _extract_actions(self, result: Any) -> ActionOutput: + """Extract actions from engine result.""" + multimodal_output = getattr(result, "multimodal_output", None) + if not isinstance(multimodal_output, Mapping): + raise RuntimeError("Missing multimodal_output in robot policy result") + + actions = multimodal_output.get("actions") + if actions is None: + raise RuntimeError("Missing multimodal_output['actions'] in robot policy result") + if isinstance(actions, Mapping): + return {str(key): np.asarray(value, dtype=np.float32) for key, value in actions.items()} + return np.asarray(actions, dtype=np.float32)