diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml index 82e872581c7..9fb48dc1f3a 100644 --- a/.buildkite/test-nightly.yml +++ b/.buildkite/test-nightly.yml @@ -603,6 +603,48 @@ steps: path: /mnt/hf-cache type: DirectoryOrCreate + - label: ":full_moon: Diffusion X2I(&A&T) · Function Test with H100 · Multi-GPU DreamZero" + timeout_in_minutes: 120 + commands: + - >- + pytest -sv + tests/e2e/online_serving/test_dreamzero_expansion.py + -m "full_model and diffusion and H100 and distributed_cuda" + --run-level "full_model" + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + - label: ":full_moon: Diffusion X2I(&A&T) · Function Test with L4" timeout_in_minutes: 60 diff --git a/examples/online_serving/dreamzero/droid_sim_eval_client.py b/examples/online_serving/dreamzero/droid_sim_eval_client.py new file mode 100644 index 00000000000..40f0f417cb5 --- /dev/null +++ b/examples/online_serving/dreamzero/droid_sim_eval_client.py @@ -0,0 +1,820 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Run one or more DROID sim-eval rollouts against the vLLM DreamZero server. 
+ +This script is the vLLM/OpenPI adaptation of the upstream DreamZero sim-eval +client: +`third_party/dreamzero/eval_utils/run_sim_eval.py` + +Behavior intentionally kept close to upstream: +- same DROID observation extraction (`external_cam`, `external_cam_2`, + `wrist_cam`, joint position, gripper position) +- same resize-with-pad preprocessing to `(180, 320)` +- same `open_loop_horizon=8` +- same gripper binarization rule (`> 0.5 -> 1`, else `0`) +- same per-scene language prompts + +Unlike upstream DreamZero, vLLM serves the compatible websocket policy endpoint +at `/v1/realtime/robot/openpi`, so this script includes the path suffix in the +client URI. + +Run this script through Isaac Lab's launcher from the vLLM-Omni repository +root, for example: + + "${ISAACLAB_LAUNCHER}" -p \ + examples/online_serving/dreamzero/droid_sim_eval_client.py \ + --host 127.0.0.1 \ + --port 8000 \ + --scene 1 \ + --episodes 1 \ + --headless \ + --device cuda:1 +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +import time +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import numpy as np +import torch + +try: + import mediapy +except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("Optional DROID sim-eval client requires `mediapy`.") from exc + +try: + from typing import override +except ImportError: + try: + from typing_extensions import override + except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("Optional DROID sim-eval client requires `typing-extensions` on Python < 3.12.") from exc + +try: + import websockets.sync.client +except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("Optional DROID sim-eval client requires `websockets`.") from exc + +# NOTE: +# This directory already contains a local file named 
`openpi_client.py`. +# However, what we want here is the *installed* `openpi_client` package from +# upstream OpenPI, not the sibling example file. When a script is executed +# directly, Python often puts the script directory at `sys.path[0]`, which +# would cause `import openpi_client` to resolve to the local example file and +# create a circular import. +# +# To avoid that ambiguity, temporarily remove the current example directory +# from the front of `sys.path`, import the real package, and then restore the +# path afterwards. +example_dir = str(Path(__file__).resolve().parent) +removed_path = False +if sys.path and sys.path[0] == example_dir: + sys.path.pop(0) + removed_path = True +try: + from openpi_client import image_tools, msgpack_numpy + from openpi_client.base_policy import BasePolicy +except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("Optional DROID sim-eval client requires the `openpi-client` package.") from exc +finally: + if removed_path: + sys.path.insert(0, example_dir) + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +# +# These values intentionally mirror the upstream DreamZero sim-eval client +# where possible. The important distinction is: +# +# - ACTION_HORIZON = 24 +# The model returns 24 future actions per inference call. +# - DEFAULT_OPEN_LOOP_HORIZON = 8 +# The sim client only executes the first 8 actions locally before asking +# the server to replan from a fresh observation. +# +# So a single server call predicts 24 actions (a 24x8 array), but the rollout +# consumes only the first 8 of them before replanning.
+PING_INTERVAL_SECS = 60 +PING_TIMEOUT_SECS = 600 +DEFAULT_PATH = "/v1/realtime/robot/openpi" +DEFAULT_OPEN_LOOP_HORIZON = 8 +ACTION_HORIZON = 24 +ACTION_DIM = 8 +DEFAULT_OUTPUT_ROOT = Path("runs") / "dreamzero_sim_eval" +SCENE_PROMPTS = { + 1: "put the cube in the bowl", + 2: "pick up the can and put it in the mug", + 3: "put the banana in the bin", +} + + +def _decode_action_response(response: bytes | str) -> np.ndarray: + if isinstance(response, str): + raise RuntimeError(f"Error in inference server:\n{response}") + decoded = msgpack_numpy.unpackb(response) + if isinstance(decoded, dict) and decoded.get("type") == "error": + message = decoded.get("message", decoded) + raise RuntimeError(f"Error in inference server:\n{message}") + return np.asarray(decoded, dtype=np.float32) + + +@dataclass(frozen=True) +class StepRecord: + """One fully materialized rollout step for later JSON export. + + The `episode_00.json` artifact is intended to be human-readable and + post-process-friendly. Instead of keeping raw tensors around, each step is + flattened into plain Python types so it can be serialized directly. + """ + + # Index of this control step within the episode. + step_index: int + # Whether this step triggered a fresh model call (as opposed to reusing the + # cached open-loop chunk from the previous server response). + used_server_call: bool + # End-to-end latency of the server call that produced the current chunk. + # This is `None` on steps that only reuse cached actions. + chunk_latency_s: float | None + # The concrete 8-D action sent into the simulator at this step. + action: list[float] + # Observed 7-DoF arm joint positions before the next environment step. + joint_position: list[float] + # Observed gripper scalar before the next environment step. + gripper_position: list[float] + # Reward and termination signals directly returned by the simulator. 
+ reward: float + terminated: bool + truncated: bool + # Optional scene object positions for downstream debugging / success + # heuristics. This may be empty if the environment does not expose them. + scene_objects: dict[str, list[float]] + + +class OpenPIWebsocketClientPolicy(BasePolicy): + """Minimal websocket client for the DreamZero/OpenPI policy protocol. + + Protocol shape: + - connect -> server immediately sends a metadata dict + - infer -> send msgpack observation, receive action chunk + - reset -> send msgpack reset command, receive confirmation string + + This class intentionally stays very small because the more interesting + DreamZero-specific behavior lives one layer above, in + `DreamZeroJointPosClient`. + """ + + def __init__( + self, + host: str = "127.0.0.1", + port: int = 8000, + path: str = DEFAULT_PATH, + ) -> None: + # vLLM serves the robot endpoint under `/v1/realtime/robot/openpi`. + self._uri = f"ws://{host}:{port}{path}" + # Upstream protocol uses msgpack with numpy support, not JSON. + self._packer = msgpack_numpy.Packer() + # Connect immediately and cache the server handshake metadata. 
+ self._ws, self._server_metadata = self._wait_for_server() + + def get_server_metadata(self) -> dict[str, Any]: + """Return a copy of the server handshake metadata.""" + + return dict(self._server_metadata) + + def _wait_for_server(self): + """Connect to the websocket server and read the initial metadata frame.""" + + logging.info("Connecting to %s", self._uri) + conn = websockets.sync.client.connect( + self._uri, + compression=None, + max_size=None, + ping_interval=PING_INTERVAL_SECS, + ping_timeout=PING_TIMEOUT_SECS, + ) + metadata = msgpack_numpy.unpackb(conn.recv()) + if not isinstance(metadata, dict): + raise TypeError(f"Expected dict metadata from server, got {type(metadata)!r}") + return conn, metadata + + @override + def infer(self, obs: dict[str, Any]) -> np.ndarray: + """Send an inference request and return the decoded action chunk.""" + + # Keep the upstream DreamZero/OpenPI convention that the request itself + # tells the server which logical endpoint is being called. + payload = dict(obs) + payload["endpoint"] = "infer" + self._ws.send(self._packer.pack(payload)) + response = self._ws.recv() + return _decode_action_response(response) + + @override + def reset(self, reset_info: dict[str, Any] | None = None) -> str: + """Tell the server to reset its session-side state.""" + + payload = dict(reset_info or {}) + payload["endpoint"] = "reset" + self._ws.send(self._packer.pack(payload)) + response = self._ws.recv() + if isinstance(response, str): + return response + decoded = msgpack_numpy.unpackb(response) + if not isinstance(decoded, dict) or decoded.get("status") != "reset successful": + raise RuntimeError(f"Unexpected reset response: {decoded!r}") + return str(decoded["status"]) + + def close(self) -> None: + """Close the websocket connection explicitly.""" + + self._ws.close() + + +class DreamZeroJointPosClient: + """DROID sim-eval client that talks to the vLLM OpenPI websocket server. 
+ + This is the main compatibility layer between: + - Isaac Lab DROID observations (`obs["policy"][...]`) + - DreamZero/OpenPI websocket payloads + - local open-loop action reuse across several simulator steps + + In other words: + simulator obs -> websocket request -> action chunk -> one action per step + """ + + def __init__( + self, + remote_host: str = "127.0.0.1", + remote_port: int = 8000, + path: str = DEFAULT_PATH, + open_loop_horizon: int = DEFAULT_OPEN_LOOP_HORIZON, + ) -> None: + # Low-level transport client. + self.client = OpenPIWebsocketClientPolicy(remote_host, remote_port, path=path) + # Number of actions to execute locally before replanning. + self.open_loop_horizon = open_loop_horizon + # Cursor into the currently cached action chunk. + self.actions_from_chunk_completed = 0 + # Most recent `(ACTION_HORIZON, ACTION_DIM)` server response. + self.pred_action_chunk: np.ndarray | None = None + # Session id is part of the DreamZero serving contract. Changing it + # causes the server side to treat the rollout as a fresh episode. + self.session_id = str(uuid.uuid4()) + # Simple runtime stats for reporting. + self.server_calls = 0 + self.last_chunk_latency_s: float | None = None + self.last_used_server_call = False + + def metadata(self) -> dict[str, Any]: + """Expose the server metadata to callers / logs.""" + + return self.client.get_server_metadata() + + def reset(self) -> str: + """Reset local chunk state and remote session state. 
+ + Local reset: + - drop cached action chunk + - rewind chunk cursor + - allocate a fresh session id + + Remote reset: + - send a websocket `reset` message so the server can clear any + request/session-side state it associates with this client + """ + + self.actions_from_chunk_completed = 0 + self.pred_action_chunk = None + self.session_id = str(uuid.uuid4()) + self.last_chunk_latency_s = None + self.last_used_server_call = False + return self.client.reset({}) + + def infer(self, obs: dict[str, Any], instruction: str) -> dict[str, Any]: + """Turn one simulator observation into one executable 8-D action. + + Key behavior: + - call the server only when the local chunk cache is empty/exhausted + - otherwise, keep consuming the cached chunk open-loop + - always return exactly one 8-D action for the current simulator step + """ + + # Convert Isaac Lab observation structure into a plain numpy-friendly + # record that is easier to serialize and visualize. + curr_obs = self._extract_observation(obs) + self.last_used_server_call = False + + # Replan if: + # 1. this is the first step of a rollout / chunk + # 2. we already consumed `open_loop_horizon` actions from the current chunk + # 3. no cached chunk is currently available + if ( + self.actions_from_chunk_completed == 0 + or self.actions_from_chunk_completed >= self.open_loop_horizon + or self.pred_action_chunk is None + ): + self.actions_from_chunk_completed = 0 + # Build the exact DreamZero/OpenPI payload expected by the server. 
+ # + # Notes: + # - images are resized/padded to the serving contract's 180x320 + # - proprio is cast to float64 to match upstream client behavior + # - cartesian_position is currently unused by DreamZero DROID, so + # a dummy zero vector is sent for protocol completeness + request_data = { + "observation/exterior_image_0_left": image_tools.resize_with_pad(curr_obs["right_image"], 180, 320), + "observation/exterior_image_1_left": image_tools.resize_with_pad(curr_obs["left_image"], 180, 320), + "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 180, 320), + "observation/joint_position": curr_obs["joint_position"].astype(np.float64), + "observation/cartesian_position": np.zeros((6,), dtype=np.float64), + "observation/gripper_position": curr_obs["gripper_position"].astype(np.float64), + "prompt": instruction, + "session_id": self.session_id, + } + + # Measure end-to-end server latency for this chunk request. + start = time.perf_counter() + actions = self.client.infer(request_data) + self.last_chunk_latency_s = time.perf_counter() - start + self.last_used_server_call = True + self.server_calls += 1 + + # DreamZero DROID serving is expected to return an action chunk with + # 24 future actions, each action being 8-D. + if actions.ndim != 2: + raise AssertionError(f"Expected 2D action array, got shape {actions.shape}") + if actions.shape != (ACTION_HORIZON, ACTION_DIM): + raise AssertionError(f"Expected action shape {(ACTION_HORIZON, ACTION_DIM)}, got {actions.shape}") + self.pred_action_chunk = actions + + # Consume exactly one action row from the cached chunk for this + # simulator step. + action = np.array(self.pred_action_chunk[self.actions_from_chunk_completed], copy=True) + self.actions_from_chunk_completed += 1 + + # Upstream DreamZero sim-eval binarizes the gripper command. 
+ action[-1] = 1.0 if action[-1].item() > 0.5 else 0.0 + + # Produce a human-friendly visualization strip for videos: + # right external | wrist | left external + img1 = image_tools.resize_with_pad(curr_obs["right_image"], 224, 224) + img2 = image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224) + img3 = image_tools.resize_with_pad(curr_obs["left_image"], 224, 224) + viz = np.concatenate([img1, img2, img3], axis=1) + + # Return both the executable action and auxiliary debug info. + return { + "action": action, + "viz": viz, + "joint_position": curr_obs["joint_position"], + "gripper_position": curr_obs["gripper_position"], + "used_server_call": self.last_used_server_call, + "chunk_latency_s": self.last_chunk_latency_s if self.last_used_server_call else None, + } + + @staticmethod + def _extract_observation(obs_dict: dict[str, Any]) -> dict[str, np.ndarray]: + """Extract the pieces DreamZero cares about from Isaac Lab observations. + + `sim-evals` exposes camera frames and robot state inside the + `obs["policy"]` group. This helper converts those tensors into numpy + arrays so they can be fed into image preprocessing / websocket packing. + """ + + policy = obs_dict["policy"] + # Isaac Lab stores camera observations as batched tensors; use env 0. + right_image = policy["external_cam"][0].clone().detach().cpu().numpy() + left_image = policy["external_cam_2"][0].clone().detach().cpu().numpy() + wrist_image = policy["wrist_cam"][0].clone().detach().cpu().numpy() + # Robot proprioception. 
+ joint_position = policy["arm_joint_pos"].clone().detach().cpu().numpy() + gripper_position = policy["gripper_pos"].clone().detach().cpu().numpy() + + return { + "right_image": right_image, + "left_image": left_image, + "wrist_image": wrist_image, + "joint_position": joint_position, + "gripper_position": gripper_position, + } + + +def _scene_instruction(scene: int) -> str: + """Map a numeric scene id onto the fixed language prompt used for rollout.""" + + try: + return SCENE_PROMPTS[scene] + except KeyError as exc: + raise ValueError(f"Unsupported scene {scene}. Available scenes: {sorted(SCENE_PROMPTS)}") from exc + + +def _capture_scene_objects(env: Any) -> dict[str, list[float]]: + """Best-effort extraction of scene object root positions. + + This is only for debugging / reporting. The rollout logic itself does not + depend on these positions. + """ + + objects: dict[str, list[float]] = {} + scene = getattr(env, "scene", None) + if scene is None: + return objects + + # Skip non-task entities such as cameras, the robot, and lighting. + for name in scene.keys(): + if name in {"robot", "external_cam", "external_cam_2", "wrist_cam", "sphere_light", "scene"}: + continue + entity = scene[name] + data = getattr(entity, "data", None) + root_pos_w = getattr(data, "root_pos_w", None) + if root_pos_w is None: + continue + # Convert tensors to plain lists for JSON serialization. + value = root_pos_w[0].detach().cpu().to(torch.float32).tolist() + objects[str(name)] = [float(x) for x in value] + return objects + + +def _maybe_infer_success(scene: int, final_objects: dict[str, list[float]]) -> dict[str, Any]: + """Best-effort geometric heuristic. + + The simulator itself does not expose a built-in success term; this function + provides a transparent fallback for human-readable reporting only. 
+ """ + + task_pairs = { + 1: ("cube", "bowl"), + 2: ("can", "mug"), + 3: ("banana", "bin"), + } + source_name, target_name = task_pairs.get(scene, (None, None)) + if source_name not in final_objects or target_name not in final_objects: + return { + "has_builtin_success": False, + "heuristic_success": None, + "reason": "scene object names unavailable for heuristic", + } + + source = np.asarray(final_objects[source_name], dtype=np.float32) + target = np.asarray(final_objects[target_name], dtype=np.float32) + xy_distance = float(np.linalg.norm(source[:2] - target[:2])) + z_delta = float(source[2] - target[2]) + heuristic_success = bool(xy_distance < 0.12 and z_delta > -0.08) + return { + "has_builtin_success": False, + "heuristic_success": heuristic_success, + "xy_distance": xy_distance, + "z_delta": z_delta, + "source_object": source_name, + "target_object": target_name, + "reason": "no built-in env success flag; using final object pose heuristic", + } + + +def _dump_json(path: Path, payload: dict[str, Any]) -> None: + """Write a JSON file with stable UTF-8 formatting.""" + + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + + +def _scalar_from_env_value(value: Any) -> float: + """Normalize simulator scalar outputs into a plain Python float. + + Isaac Lab / Gym values may come back as tensors, numpy arrays, tuples, or + direct Python scalars depending on the wrapper stack. Centralizing the + conversion here makes the rollout loop cleaner and more robust. 
+ """ + + if isinstance(value, torch.Tensor): + return float(value.reshape(-1)[0].detach().cpu().item()) + if isinstance(value, np.ndarray): + return float(value.reshape(-1)[0]) + if isinstance(value, (list, tuple)): + return float(np.asarray(value).reshape(-1)[0]) + return float(value) + + +def _bool_from_env_value(value: Any) -> bool: + """Normalize simulator boolean-like outputs into a plain Python bool.""" + + if isinstance(value, torch.Tensor): + return bool(value.reshape(-1)[0].detach().cpu().item()) + if isinstance(value, np.ndarray): + return bool(value.reshape(-1)[0]) + if isinstance(value, (list, tuple)): + return bool(np.asarray(value).reshape(-1)[0]) + return bool(value) + + +def _make_output_dir(output_root: Path, scene: int) -> Path: + """Create a timestamped output directory for one sim-eval run.""" + + timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + output_dir = output_root / f"scene{scene}_{timestamp}" + output_dir.mkdir(parents=True, exist_ok=True) + return output_dir + + +def main() -> None: + """Entry point for one or more DROID simulation rollouts. + + High-level flow: + 1. parse command-line flags + 2. bootstrap Isaac Lab / sim-evals imports + 3. create the DROID environment + 4. connect the DreamZero websocket client + 5. run `episodes` rollouts + 6. export videos + JSON summaries + """ + + # Script-level arguments. 
+ parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--episodes", type=int, default=1, help="Number of episodes to run.") + parser.add_argument("--scene", type=int, default=1, help="DROID scene id (1/2/3).") + parser.add_argument("--host", type=str, default="127.0.0.1", help="vLLM DreamZero server host.") + parser.add_argument("--port", type=int, default=8000, help="vLLM DreamZero server port.") + parser.add_argument("--path", type=str, default=DEFAULT_PATH, help="Websocket path suffix.") + parser.add_argument( + "--open-loop-horizon", + type=int, + default=DEFAULT_OPEN_LOOP_HORIZON, + help="How many actions to consume locally before requesting the next chunk.", + ) + parser.add_argument( + "--output-root", + type=Path, + default=DEFAULT_OUTPUT_ROOT, + help="Directory where videos and trajectory logs are stored.", + ) + + try: + from isaaclab.app import AppLauncher + except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError( + "Optional DROID sim-eval client requires Isaac Lab (`isaaclab`). " + "Launch it from an Isaac Lab environment, e.g. via `isaaclab.sh -p`." + ) from exc + + # Let Isaac Lab inject its own runtime flags (e.g. headless, device). + AppLauncher.add_app_launcher_args(parser) + args = parser.parse_args() + + # DreamZero sim-eval always needs camera observations enabled. + args.enable_cameras = True + # Boot Isaac Sim / Isaac Lab. + app_launcher = AppLauncher(args) + simulation_app = app_launcher.app + # Set defaults so the `finally` block can clean up safely even if an + # earlier step fails. + env = None + client = None + + # Import simulator modules only *after* the app is launched. This matches + # Isaac Lab's required import ordering. 
+ try: + import gymnasium as gym + except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("Optional DROID sim-eval client requires `gymnasium`.") from exc + + try: + import sim_evals.environments # noqa: F401 + except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError( + "Optional DROID sim-eval client requires the external `sim-evals` package or checkout to be importable." + ) from exc + + try: + from isaaclab_tasks.utils import parse_env_cfg + except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("Optional DROID sim-eval client requires `isaaclab_tasks`.") from exc + + # Resolve output location and scene prompt. + output_dir = _make_output_dir(args.output_root.expanduser().resolve(), args.scene) + instruction = _scene_instruction(args.scene) + + # Build the DROID environment configuration from `sim-evals`. + env_cfg = parse_env_cfg( + "DROID", + device=args.device, + num_envs=1, + use_fabric=True, + ) + # Select one of the pre-authored scenes/tasks. + env_cfg.set_scene(args.scene) + env = gym.make("DROID", cfg=env_cfg) + + # Upstream sim-evals resets twice so materials / cameras are fully ready. + obs, _ = env.reset() + obs, _ = env.reset() + + # Connect the websocket policy client. + client = DreamZeroJointPosClient( + remote_host=args.host, + remote_port=args.port, + path=args.path, + open_loop_horizon=args.open_loop_horizon, + ) + + # Aggregated per-run results. + all_episode_summaries: list[dict[str, Any]] = [] + max_steps = int(env.env.max_episode_length) + logging.info("DreamZero metadata: %s", client.metadata()) + logging.info("Scene %s prompt: %s", args.scene, instruction) + logging.info("Writing outputs to %s", output_dir) + + try: + # No gradients are needed in inference-only rollout mode. + with torch.no_grad(): + for episode_index in range(args.episodes): + # Per-episode collectors. 
+ frames: list[np.ndarray] = [] + step_records: list[StepRecord] = [] + episode_start = time.perf_counter() + server_time_s = 0.0 + final_reward = 0.0 + terminated = False + truncated = False + + for step_index in range(max_steps): + # Ask the policy for the next action. Internally this may or + # may not trigger a real server request depending on whether + # the local chunk cache has been exhausted. + logging.debug("Episode %d step %d: requesting action", episode_index, step_index) + result = client.infer(obs, instruction) + logging.debug( + "Episode %d step %d: got action (server_call=%s latency=%s)", + episode_index, + step_index, + result["used_server_call"], + result["chunk_latency_s"], + ) + # Save one visualization frame per simulator step. + frames.append(result["viz"]) + + # Isaac Lab expects batched actions, hence `[None]`. + action_tensor = torch.tensor(result["action"], dtype=torch.float32)[None] + logging.debug("Episode %d step %d: stepping env", episode_index, step_index) + obs, reward, term, trunc, info = env.step(action_tensor) + logging.debug("Episode %d step %d: env.step returned", episode_index, step_index) + logging.debug("Episode %d step %d: parsing reward/flags", episode_index, step_index) + logging.debug( + "Episode %d step %d: raw types reward=%s term=%s trunc=%s", + episode_index, + step_index, + type(reward).__name__, + type(term).__name__, + type(trunc).__name__, + ) + # Normalize environment outputs into plain Python scalars so + # the rest of the code does not depend on wrapper-specific types. + reward_value = _scalar_from_env_value(reward) + term_value = _bool_from_env_value(term) + trunc_value = _bool_from_env_value(trunc) + logging.debug( + "Episode %d step %d: parsed reward=%s term=%s trunc=%s", + episode_index, + step_index, + reward_value, + term_value, + trunc_value, + ) + # Keep scene-object capture optional. 
It is useful for + # debugging / success heuristics, but the rollout should not + # fail if the environment does not expose object roots. + scene_objects = _capture_scene_objects(env) + + # Accumulate total server-side time only on steps that + # triggered a fresh chunk inference. + if result["chunk_latency_s"] is not None: + server_time_s += float(result["chunk_latency_s"]) + + # Materialize one JSON-serializable trajectory record. + logging.debug("Episode %d step %d: appending trajectory", episode_index, step_index) + step_records.append( + StepRecord( + step_index=step_index, + used_server_call=bool(result["used_server_call"]), + chunk_latency_s=( + float(result["chunk_latency_s"]) if result["chunk_latency_s"] is not None else None + ), + action=[float(x) for x in np.asarray(result["action"], dtype=np.float32).tolist()], + joint_position=[ + float(x) for x in np.asarray(result["joint_position"], dtype=np.float32).tolist() + ], + gripper_position=[ + float(x) for x in np.asarray(result["gripper_position"], dtype=np.float32).tolist() + ], + reward=reward_value, + terminated=term_value, + truncated=trunc_value, + scene_objects=scene_objects, + ) + ) + logging.debug("Episode %d step %d: trajectory appended", episode_index, step_index) + + # Track final status for summary export. + final_reward = reward_value + terminated = term_value + truncated = trunc_value + if term_value or trunc_value: + # End the rollout early if the environment terminates. + break + + # Episode-level timing and video export. + episode_wall_time_s = time.perf_counter() - episode_start + video_path = output_dir / f"episode_{episode_index:02d}.mp4" + logging.info("Episode %d: writing video to %s", episode_index, video_path) + mediapy.write_video(video_path, frames, fps=15) + + # Reset the policy server between episodes. 
+ logging.info("Episode %d: sending reset", episode_index) + reset_response = client.reset() + final_objects = step_records[-1].scene_objects if step_records else {} + success_report = _maybe_infer_success(args.scene, final_objects) + + # Assemble the per-episode summary that is written to + # `episode_XX.json`. + episode_summary = { + "episode_index": episode_index, + "prompt": instruction, + "video_path": str(video_path), + "steps_executed": len(step_records), + "max_steps": max_steps, + "terminated": terminated, + "truncated": truncated, + "final_reward": final_reward, + "server_calls": client.server_calls, + "server_time_s": server_time_s, + "episode_wall_time_s": episode_wall_time_s, + "avg_server_time_per_call_s": ( + server_time_s / client.server_calls if client.server_calls else None + ), + "reset_response": reset_response, + "success_report": success_report, + "server_metadata": client.metadata(), + "trajectory": [record.__dict__ for record in step_records], + } + _dump_json(output_dir / f"episode_{episode_index:02d}.json", episode_summary) + all_episode_summaries.append(episode_summary) + + logging.info( + "Episode %d done: steps=%d wall=%.2fs server_calls=%d heuristic_success=%s", + episode_index, + len(step_records), + episode_wall_time_s, + client.server_calls, + success_report.get("heuristic_success"), + ) + + # Reset per-episode counters while keeping the client alive. + client.server_calls = 0 + + # Top-level run summary across all episodes. + summary = { + "scene": args.scene, + "prompt": instruction, + "episodes": args.episodes, + "host": args.host, + "port": args.port, + "path": args.path, + "device": args.device, + "output_dir": str(output_dir), + "server_metadata": client.metadata(), + "episodes_summary": all_episode_summaries, + } + _dump_json(output_dir / "summary.json", summary) + # Also print the summary to stdout so the caller can capture it in logs. 
+ print(json.dumps(summary, ensure_ascii=False, indent=2)) + finally: + # Best-effort cleanup. Avoid masking the main error if cleanup fails. + try: + if client is not None: + client.client.close() + except Exception: + pass + if env is not None: + env.close() + simulation_app.close() + + +if __name__ == "__main__": + # Keep the script-level logs readable. Per-step rollout details are still + # available via `DEBUG` if needed, but websocket / asyncio internals are + # usually too noisy for normal usage. + logging.basicConfig(level=logging.INFO) + logging.getLogger("websockets").setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) + main() diff --git a/examples/online_serving/dreamzero/export_prediction_video.py b/examples/online_serving/dreamzero/export_prediction_video.py new file mode 100644 index 00000000000..2fa59ae9e6c --- /dev/null +++ b/examples/online_serving/dreamzero/export_prediction_video.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import uuid +from pathlib import Path + +import cv2 +import numpy as np +import torch +from PIL import Image + +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput + +WORKER_EXTENSION = "vllm_omni.diffusion.models.dreamzero.video_export_worker.DreamZeroVideoExportWorkerExtension" +DEFAULT_MODEL = "GEAR-Dreams/DreamZero-DROID" +DEFAULT_PROMPT = "Move the pan forward and use the brush in the middle of the plates to brush the inside of the pan" +REPO_ROOT = Path(__file__).resolve().parents[3] +ASSET_REPO_ID = "YangshenDeng/vllm-omni-dreamzero-assets" +DEFAULT_VIDEO_DIR = REPO_ROOT / "outputs" / "dreamzero" / "assets" +DEFAULT_OUTPUT_DIR = REPO_ROOT / "outputs" / "dreamzero" / "generated_predictions" +DEFAULT_OUTPUT_STEM = "dreamzero_prediction" 
+DEFAULT_SESSION_PREFIX = "dreamzero-export" +RELATIVE_OFFSETS = [-23, -16, -8, 0] +ACTION_HORIZON = 24 +CAMERA_FILES = { + "observation/exterior_image_0_left": "exterior_image_1_left.mp4", + "observation/exterior_image_1_left": "exterior_image_2_left.mp4", + "observation/wrist_image_left": "wrist_image_left.mp4", +} + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Export DreamZero prediction video from downloaded example inputs.") + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument("--deploy-config", type=Path, required=True) + parser.add_argument( + "--video-dir", type=Path, default=DEFAULT_VIDEO_DIR, help="Directory containing the three camera MP4 files." + ) + parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR) + parser.add_argument("--output-stem", default=DEFAULT_OUTPUT_STEM) + parser.add_argument("--prompt", default=DEFAULT_PROMPT) + parser.add_argument("--session-id", default=None) + parser.add_argument("--save-input-video", action="store_true") + parser.add_argument("--save-gif", action="store_true") + parser.add_argument("--save-actions", action="store_true") + parser.add_argument("--fps", type=int, default=5) + return parser.parse_args() + + +def _load_all_frames(video_path: Path) -> np.ndarray: + cap = cv2.VideoCapture(str(video_path)) + frames = [] + while True: + ok, frame = cap.read() + if not ok: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + if not frames: + raise RuntimeError(f"No frames loaded from {video_path}") + return np.stack(frames, axis=0) + + +def _load_camera_frames(video_dir: Path) -> dict[str, np.ndarray]: + camera_frames: dict[str, np.ndarray] = {} + for camera_key, file_name in CAMERA_FILES.items(): + video_path = video_dir / file_name + if not video_path.exists(): + raise FileNotFoundError( + f"Missing DreamZero example asset: {video_path}. 
" + "Download the example videos with: " + f"`hf download {ASSET_REPO_ID} --repo-type dataset --local-dir {video_dir}`" + ) + camera_frames[camera_key] = _load_all_frames(video_path) + return camera_frames + + +def _build_frame_schedule(total_frames: int, num_chunks: int) -> list[list[int]]: + chunks: list[list[int]] = [] + current_frame = 23 + for _ in range(num_chunks): + indices = [max(current_frame + offset, 0) for offset in RELATIVE_OFFSETS] + if indices[-1] >= total_frames: + break + chunks.append(indices) + current_frame += ACTION_HORIZON + return chunks + + +def _make_obs_from_video( + camera_frames: dict[str, np.ndarray], + frame_indices: list[int], + *, + prompt: str, + session_id: str, +) -> dict: + obs: dict = {} + for camera_key, all_frames in camera_frames.items(): + selected = all_frames[frame_indices] + obs[camera_key] = selected[0] if len(frame_indices) == 1 else selected + + obs["observation/joint_position"] = np.zeros(7, dtype=np.float32) + obs["observation/cartesian_position"] = np.zeros(6, dtype=np.float32) + obs["observation/gripper_position"] = np.zeros(1, dtype=np.float32) + obs["prompt"] = prompt + obs["session_id"] = session_id + return obs + + +def _build_observations(video_dir: Path, prompt: str, session_id: str) -> tuple[dict[str, np.ndarray], list[dict]]: + camera_frames = _load_camera_frames(video_dir) + total_frames = min(frames.shape[0] for frames in camera_frames.values()) + chunks = _build_frame_schedule(total_frames, 1) + observations = [ + _make_obs_from_video(camera_frames, [0], prompt=prompt, session_id=session_id), + ] + if chunks: + observations.append( + _make_obs_from_video( + camera_frames, + chunks[0], + prompt=prompt, + session_id=session_id, + ) + ) + if len(observations) < 2: + raise RuntimeError("Need at least two DreamZero example observations to export a prediction video.") + return camera_frames, observations[:2] + + +def _extract_latents(output: OmniRequestOutput) -> torch.Tensor: + if not isinstance(output, 
OmniRequestOutput): + raise TypeError(f"Expected OmniRequestOutput, got {type(output)!r}") + if not output.images: + raise RuntimeError("DreamZero output does not contain video latents in `images`.") + + latents = output.images[0] + if not isinstance(latents, torch.Tensor): + raise TypeError(f"Expected tensor latents, got {type(latents)!r}") + + latents = latents.detach().cpu() + if latents.dim() == 4: + latents = latents.unsqueeze(0) + if latents.dim() != 5: + raise ValueError(f"Unexpected latent shape: {tuple(latents.shape)}") + + if latents.shape[1] < latents.shape[2]: + latents = latents.transpose(1, 2).contiguous() + return latents + + +def _write_mp4(path: Path, frames: np.ndarray, fps: int) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + height, width = frames.shape[1:3] + writer = cv2.VideoWriter( + str(path), + cv2.VideoWriter_fourcc(*"mp4v"), + float(fps), + (width, height), + ) + if not writer.isOpened(): + raise RuntimeError(f"Failed to open video writer for {path}") + try: + for frame in frames: + writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + finally: + writer.release() + + +def _write_gif(path: Path, frames: np.ndarray, fps: int) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + images = [Image.fromarray(frame) for frame in frames] + duration_ms = max(int(round(1000 / max(fps, 1))), 1) + images[0].save( + path, + save_all=True, + append_images=images[1:], + duration=duration_ms, + loop=0, + ) + + +def _stitch_input_frames(camera_frames: dict[str, np.ndarray]) -> np.ndarray: + total_frames = min(frames.shape[0] for frames in camera_frames.values()) + stitched = [] + for frame_index in range(total_frames): + left = camera_frames["observation/exterior_image_0_left"][frame_index] + right = camera_frames["observation/exterior_image_1_left"][frame_index] + wrist = camera_frames["observation/wrist_image_left"][frame_index] + pad = np.zeros((left.shape[0], left.shape[1], 3), dtype=np.uint8) + canvas = np.concatenate([left, 
right], axis=1) + bottom = np.concatenate([wrist, pad], axis=1) + stitched.append(np.concatenate([canvas, bottom], axis=0)) + return np.stack(stitched, axis=0) + + +def _run_generation( + model: str, deploy_config_path: Path, observations: list[dict] +) -> tuple[Omni, list[OmniRequestOutput]]: + omni = Omni( + model=model, + deploy_config=str(deploy_config_path), + enforce_eager=True, + worker_extension_cls=WORKER_EXTENSION, + ) + + outputs: list[OmniRequestOutput] = [] + for index, obs in enumerate(observations): + sampling_params = OmniDiffusionSamplingParams( + extra_args={ + "reset": index == 0, + "session_id": obs["session_id"], + "robot_obs": obs, + } + ) + result = omni.generate(obs["prompt"], sampling_params_list=[sampling_params]) + if not result: + raise RuntimeError(f"No output returned for DreamZero request {index}") + outputs.append(result[0]) + return omni, outputs + + +def _decode_with_worker(omni: Omni, full_latents: torch.Tensor) -> np.ndarray: + stage_client = omni.engine.stage_clients[0] + engine = getattr(stage_client, "_engine", None) + if engine is None: + raise RuntimeError("DreamZero export requires inline diffusion stage access.") + + decoded = engine.executor.collective_rpc( + "decode_video_latents_to_uint8", + args=(full_latents,), + unique_reply_rank=0, + exec_all_ranks=True, + ) + if isinstance(decoded, torch.Tensor): + decoded = decoded.numpy() + if not isinstance(decoded, np.ndarray): + raise TypeError(f"Unexpected decoded output type: {type(decoded)!r}") + return decoded + + +def main() -> None: + args = _parse_args() + session_id = args.session_id or f"{DEFAULT_SESSION_PREFIX}-{uuid.uuid4()}" + + camera_frames, observations = _build_observations( + video_dir=args.video_dir, + prompt=args.prompt, + session_id=session_id, + ) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + if args.save_input_video: + input_frames = _stitch_input_frames(camera_frames) + _write_mp4(args.output_dir / f"{args.output_stem}_input.mp4", 
input_frames, fps=15) + if args.save_gif: + _write_gif(args.output_dir / f"{args.output_stem}_input.gif", input_frames[::3], fps=5) + + omni = None + try: + omni, outputs = _run_generation( + model=args.model, + deploy_config_path=args.deploy_config, + observations=observations, + ) + latent_steps = [_extract_latents(output) for output in outputs] + full_latents = torch.cat(latent_steps, dim=2) + frames = _decode_with_worker(omni, full_latents) + finally: + if omni is not None: + omni.close() + + mp4_path = args.output_dir / f"{args.output_stem}.mp4" + + _write_mp4(mp4_path, frames, fps=args.fps) + print(f"SAVED_MP4={mp4_path}") + + if args.save_gif: + gif_path = args.output_dir / f"{args.output_stem}.gif" + _write_gif(gif_path, frames, fps=args.fps) + print(f"SAVED_GIF={gif_path}") + + if args.save_actions: + npz_path = args.output_dir / f"{args.output_stem}_actions.npz" + np.savez( + npz_path, + step0=np.asarray(outputs[0].multimodal_output.get("actions")), + step1=np.asarray(outputs[1].multimodal_output.get("actions")), + ) + print(f"SAVED_ACTIONS={npz_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/dreamzero/molmospace_dreamzero_eval_demo.py b/examples/online_serving/dreamzero/molmospace_dreamzero_eval_demo.py new file mode 100644 index 00000000000..11f5dc68b86 --- /dev/null +++ b/examples/online_serving/dreamzero/molmospace_dreamzero_eval_demo.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +os.environ.setdefault("MUJOCO_GL", "egl") +os.environ.setdefault("PYOPENGL_PLATFORM", "egl") + +_DEMO_HOST = os.environ.get("VLLM_OMNI_DEMO_HOST", "127.0.0.1") +_DEMO_PORT = int(os.environ.get("VLLM_OMNI_DEMO_PORT", "8000")) + +# Import base configs at module top level so the subclasses below are pickle- +# resolvable (worker processes import this module fresh via __main__). 
+from molmo_spaces.configs.policy_configs_baselines import ( # noqa: E402 + DreamZeroPolicyConfig, +) +from molmo_spaces.evaluation.configs.evaluation_configs import ( # noqa: E402 + DreamZeroPolicyEvalConfig, +) + + +# We only need to change the backend host and port to the vllm-host! +class DreamZeroVllmOmniPolicyConfig(DreamZeroPolicyConfig): + remote_config: dict = dict(host=_DEMO_HOST, port=_DEMO_PORT) + + +class DreamZeroVllmOmniEvalConfig(DreamZeroPolicyEvalConfig): + policy_config: DreamZeroVllmOmniPolicyConfig = DreamZeroVllmOmniPolicyConfig() + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--benchmark_dir", + required=True, + help=( + "Path to a MolmoSpaces benchmark directory, for example " + "$MOLMOSPACES_BENCHMARK_DIR/20260327/ithor/FrankaCloseHardBench/" + "FrankaCloseHardBench_20260206_json_benchmark" + ), + ) + parser.add_argument("--max_episodes", type=int, default=1) + parser.add_argument("--task_horizon_steps", type=int, default=80) + parser.add_argument( + "--output_dir", + required=True, + help="Directory to write evaluation outputs (created if missing).", + ) + parser.add_argument("--episode_idx", type=int, default=None) + args = parser.parse_args() + + os.environ["VLLM_OMNI_DEMO_HOST"] = args.host + os.environ["VLLM_OMNI_DEMO_PORT"] = str(args.port) + DreamZeroVllmOmniPolicyConfig.model_fields["remote_config"].default = dict(host=args.host, port=args.port) + + # Import after env vars are set so MuJoCo picks EGL. 
+ from molmo_spaces.evaluation import run_evaluation + + cfg_cls = DreamZeroVllmOmniEvalConfig + + output_dir = args.output_dir + Path(output_dir).mkdir(parents=True, exist_ok=True) + + print(f"[eval] benchmark_dir={args.benchmark_dir}") + print(f"[eval] max_episodes={args.max_episodes} task_horizon_steps={args.task_horizon_steps}") + print(f"[eval] remote policy: ws://{args.host}:{args.port}/v1/realtime/robot/openpi") + + results = run_evaluation( + eval_config_cls=cfg_cls, + benchmark_dir=Path(args.benchmark_dir), + max_episodes=args.max_episodes, + task_horizon_steps=args.task_horizon_steps, + num_workers=1, + use_wandb=False, + output_dir=output_dir, + episode_idx=args.episode_idx, + ) + + print(f"[eval] success={results.success_count}/{results.total_count} ({results.success_rate:.1%})") + print(f"[eval] output_dir={results.output_dir}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/online_serving/dreamzero/openpi_client.py b/examples/online_serving/dreamzero/openpi_client.py new file mode 100755 index 00000000000..114817a9b18 --- /dev/null +++ b/examples/online_serving/dreamzero/openpi_client.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import json +import logging +import sys +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np + +try: + import cv2 +except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("DreamZero OpenPI example requires `opencv-python`.") from exc + +try: + import websockets.sync.client +except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("DreamZero OpenPI example requires `websockets`.") from exc + +try: + example_dir = str(Path(__file__).resolve().parent) + removed_path = False + if sys.path and 
sys.path[0] == example_dir: + sys.path.pop(0) + removed_path = True + try: + from openpi_client import msgpack_numpy + finally: + if removed_path: + sys.path.insert(0, example_dir) +except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("DreamZero OpenPI example requires `openpi-client`.") from exc + +PING_INTERVAL_SECS = 300 +PING_TIMEOUT_SECS = 3600 +DEFAULT_HOST = "127.0.0.1" +DEFAULT_PORT = 8000 +DEFAULT_PATH = "/v1/realtime/robot/openpi" +DEFAULT_PROMPT = "Move the pan forward and use the brush in the middle of the plates to brush the inside of the pan" +ACTION_HORIZON = 24 +DEFAULT_ACTION_DIM = 8 +RELATIVE_OFFSETS = [-23, -16, -8, 0] +REPO_ROOT = Path(__file__).resolve().parents[3] +ASSET_REPO_ID = "YangshenDeng/vllm-omni-dreamzero-assets" +DEFAULT_VIDEO_DIR = REPO_ROOT / "outputs" / "dreamzero" / "assets" +CAMERA_FILES = { + "observation/exterior_image_0_left": "exterior_image_1_left.mp4", + "observation/exterior_image_1_left": "exterior_image_2_left.mp4", + "observation/wrist_image_left": "wrist_image_left.mp4", +} + + +def _decode_action_response(response: bytes | str) -> np.ndarray: + if isinstance(response, str): + raise RuntimeError(f"Inference failed: {response}") + decoded = msgpack_numpy.unpackb(response) + if isinstance(decoded, dict) and decoded.get("type") == "error": + message = decoded.get("message", decoded) + raise RuntimeError(f"Inference failed: {message}") + return np.asarray(decoded, dtype=np.float32) + + +@dataclass(frozen=True) +class DreamZeroServerMetadata: + image_resolution: tuple[int, int] + n_external_cameras: int + needs_wrist_camera: bool + needs_stereo_camera: bool + needs_session_id: bool + action_space: str + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> DreamZeroServerMetadata: + required_keys = ( + "image_resolution", + "n_external_cameras", + "needs_wrist_camera", + "needs_stereo_camera", + "needs_session_id", + "action_space", + ) + missing_keys = [key for key in 
required_keys if key not in payload] + if missing_keys: + raise ValueError(f"Missing DreamZero metadata keys: {missing_keys}") + + image_resolution = payload["image_resolution"] + if not isinstance(image_resolution, (list, tuple)) or len(image_resolution) != 2: + raise ValueError(f"Invalid image_resolution: {image_resolution!r}") + + return cls( + image_resolution=(int(image_resolution[0]), int(image_resolution[1])), + n_external_cameras=int(payload["n_external_cameras"]), + needs_wrist_camera=bool(payload["needs_wrist_camera"]), + needs_stereo_camera=bool(payload["needs_stereo_camera"]), + needs_session_id=bool(payload["needs_session_id"]), + action_space=str(payload["action_space"]), + ) + + +class OpenPIWebsocketClient: + def __init__( + self, + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + ) -> None: + self._uri = f"ws://{host}:{port}{path}" + self._packer = msgpack_numpy.Packer() + self._ws, self._server_metadata = self._connect() + + def _connect(self): + logging.info("Connecting to %s", self._uri) + conn = websockets.sync.client.connect( + self._uri, + compression=None, + max_size=None, + ping_interval=PING_INTERVAL_SECS, + ping_timeout=PING_TIMEOUT_SECS, + ) + metadata = msgpack_numpy.unpackb(conn.recv()) + if not isinstance(metadata, dict): + raise TypeError(f"Expected dict metadata from server, got {type(metadata)!r}") + return conn, metadata + + def get_server_metadata(self) -> dict[str, Any]: + return dict(self._server_metadata) + + def infer(self, obs: dict[str, Any]) -> np.ndarray: + payload = dict(obs) + payload["endpoint"] = "infer" + self._ws.send(self._packer.pack(payload)) + response = self._ws.recv() + return _decode_action_response(response) + + def reset(self, reset_info: dict[str, Any] | None = None) -> str: + payload = dict(reset_info or {}) + payload["endpoint"] = "reset" + self._ws.send(self._packer.pack(payload)) + response = self._ws.recv() + if isinstance(response, str): + return response + 
decoded = msgpack_numpy.unpackb(response) + if not isinstance(decoded, dict) or decoded.get("status") != "reset successful": + raise RuntimeError(f"Unexpected reset response: {decoded!r}") + return str(decoded["status"]) + + def close(self) -> None: + self._ws.close() + + +def load_all_frames(video_path: Path) -> np.ndarray: + cap = cv2.VideoCapture(str(video_path)) + frames = [] + while True: + ok, frame = cap.read() + if not ok: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + if not frames: + raise RuntimeError(f"No frames loaded from {video_path}") + return np.stack(frames, axis=0) + + +def load_camera_frames(video_dir: Path) -> dict[str, np.ndarray]: + camera_frames: dict[str, np.ndarray] = {} + for camera_key, file_name in CAMERA_FILES.items(): + video_path = video_dir / file_name + if not video_path.exists(): + raise FileNotFoundError( + f"Missing DreamZero example asset: {video_path}. " + "Download the example videos with: " + f"`hf download {ASSET_REPO_ID} --repo-type dataset --local-dir {video_dir}`" + ) + camera_frames[camera_key] = load_all_frames(video_path) + return camera_frames + + +def build_frame_schedule(total_frames: int, num_chunks: int) -> list[list[int]]: + chunks: list[list[int]] = [] + current_frame = 23 + for _ in range(num_chunks): + indices = [max(current_frame + offset, 0) for offset in RELATIVE_OFFSETS] + if indices[-1] >= total_frames: + break + chunks.append(indices) + current_frame += ACTION_HORIZON + return chunks + + +def make_obs_from_video( + camera_frames: dict[str, np.ndarray], + frame_indices: list[int], + *, + prompt: str, + session_id: str, +) -> dict[str, Any]: + obs: dict[str, Any] = {} + for camera_key, all_frames in camera_frames.items(): + selected = all_frames[frame_indices] + obs[camera_key] = selected[0] if len(frame_indices) == 1 else selected + + obs["observation/joint_position"] = np.zeros(7, dtype=np.float32) + obs["observation/cartesian_position"] = np.zeros(6, dtype=np.float32) 
+ obs["observation/gripper_position"] = np.zeros(1, dtype=np.float32) + obs["prompt"] = prompt + obs["session_id"] = session_id + return obs + + +def build_demo_observations( + camera_frames: dict[str, np.ndarray], + *, + prompt: str, + session_id: str, + num_chunks: int = 2, +) -> list[dict[str, Any]]: + if num_chunks < 1: + raise ValueError("num_chunks must be at least 1") + + total_frames = min(frames.shape[0] for frames in camera_frames.values()) + observations = [ + make_obs_from_video( + camera_frames, + [0], + prompt=prompt, + session_id=session_id, + ) + ] + for indices in build_frame_schedule(total_frames, num_chunks - 1): + observations.append( + make_obs_from_video( + camera_frames, + indices, + prompt=prompt, + session_id=session_id, + ) + ) + return observations + + +def validate_session_result( + result: dict[str, Any], + *, + expected_action_horizon: int = ACTION_HORIZON, + expected_action_dim: int = DEFAULT_ACTION_DIM, +) -> None: + metadata = DreamZeroServerMetadata.from_dict(result["metadata"]) + if metadata.image_resolution != (180, 320): + raise AssertionError(f"Unexpected image_resolution: {metadata.image_resolution}") + if metadata.n_external_cameras != 2: + raise AssertionError(f"Unexpected n_external_cameras: {metadata.n_external_cameras}") + if not metadata.needs_wrist_camera: + raise AssertionError("DreamZero example expects wrist camera metadata") + if metadata.action_space != "joint_position": + raise AssertionError(f"Unexpected action_space: {metadata.action_space}") + + actions = result["actions"] + if len(actions) != 3: + raise AssertionError(f"Expected 3 action tensors, got {len(actions)}") + for index, action in enumerate(actions): + if action.shape != (expected_action_horizon, expected_action_dim): + raise AssertionError( + f"Action {index} shape mismatch: expected " + f"{(expected_action_horizon, expected_action_dim)}, got {action.shape}" + ) + if not np.isfinite(action).all(): + raise AssertionError(f"Action {index} contains 
non-finite values") + + if result["reset_status"] != "reset successful": + raise AssertionError(f"Unexpected reset status: {result['reset_status']!r}") + + +def run_policy_session( + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + video_dir: Path = DEFAULT_VIDEO_DIR, + prompt: str = DEFAULT_PROMPT, + session_id: str | None = None, + num_chunks: int = 2, +) -> dict[str, Any]: + session_id = session_id or str(uuid.uuid4()) + camera_frames = load_camera_frames(video_dir) + observations = build_demo_observations( + camera_frames, + prompt=prompt, + session_id=session_id, + num_chunks=num_chunks, + ) + + client = OpenPIWebsocketClient(host=host, port=port, path=path) + try: + metadata = client.get_server_metadata() + actions = [client.infer(obs) for obs in observations] + reset_status = client.reset({}) + actions.append(client.infer(observations[0])) + return { + "metadata": metadata, + "actions": actions, + "reset_status": reset_status, + "session_id": session_id, + } + finally: + client.close() + + +def format_action_summary(index: int, action: np.ndarray) -> str: + return ( + f"Action {index}: shape={tuple(action.shape)} dtype={action.dtype} " + f"min={action.min():.6f} max={action.max():.6f}" + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="DreamZero OpenPI client example with downloaded real videos.") + parser.add_argument("--host", default=DEFAULT_HOST) + parser.add_argument("--port", type=int, default=DEFAULT_PORT) + parser.add_argument("--path", default=DEFAULT_PATH) + parser.add_argument( + "--video-dir", type=Path, default=DEFAULT_VIDEO_DIR, help="Directory containing the three camera MP4 files." 
+ ) + parser.add_argument("--prompt", default=DEFAULT_PROMPT) + parser.add_argument("--session-id", default=None) + parser.add_argument("--num-chunks", type=int, default=2) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + logging.basicConfig(level=logging.INFO) + + result = run_policy_session( + host=args.host, + port=args.port, + path=args.path, + video_dir=args.video_dir, + prompt=args.prompt, + session_id=args.session_id, + num_chunks=args.num_chunks, + ) + validate_session_result(result) + + print("Server metadata:", json.dumps(result["metadata"], sort_keys=True)) + for index, action in enumerate(result["actions"]): + print(format_action_summary(index, action)) + print("Reset status:", result["reset_status"]) + print("Session ID:", result["session_id"]) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/online_serving/dreamzero/run_server.sh b/examples/online_serving/dreamzero/run_server.sh new file mode 100755 index 00000000000..c28e0aee5da --- /dev/null +++ b/examples/online_serving/dreamzero/run_server.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +set -euo pipefail + +MODEL="${MODEL:-GEAR-Dreams/DreamZero-DROID}" +HOST="${HOST:-127.0.0.1}" +PORT="${PORT:-8000}" +DEPLOY_CONFIG="${DEPLOY_CONFIG:-vllm_omni/deploy/dreamzero_tp1_cfg2.yaml}" +SERVED_MODEL_NAME="${SERVED_MODEL_NAME:-dreamzero-droid}" + +args=( + serve + "$MODEL" + --omni + --host "$HOST" + --port "$PORT" + --served-model-name "$SERVED_MODEL_NAME" + --deploy-config "$DEPLOY_CONFIG" + --enforce-eager + --disable-log-stats +) + +ATTENTION_BACKEND="${ATTENTION_BACKEND:-torch}" \ +DIFFUSION_ATTENTION_BACKEND="${DIFFUSION_ATTENTION_BACKEND:-TORCH_SDPA}" \ +vllm "${args[@]}" diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index ca28b26294e..82c6f087075 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -1,7 +1,6 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio import queue import threading from types import SimpleNamespace @@ -484,11 +483,7 @@ def test_scheduler_alias_keeps_default_request_scheduler(self) -> None: @pytest.mark.asyncio async def test_step_raises_aborted_error(self, mocker: MockerFixture) -> None: engine = DiffusionEngine.__new__(DiffusionEngine) - engine._closed = False - engine._loop_started = True - engine._init_lock = asyncio.Lock() - engine.main_loop = asyncio.get_running_loop() - engine.stop_event = threading.Event() + engine._check_and_start_background_loop = mocker.AsyncMock() engine.pre_process_func = None engine.async_add_req_and_wait_for_response = mocker.AsyncMock( return_value=DiffusionOutput(aborted=True, abort_message="Request req-abort aborted.") @@ -563,6 +558,52 @@ def test_dummy_run_raises_on_output_error(self, mocker: MockerFixture) -> None: with pytest.raises(RuntimeError, match="Dummy run failed: boom"): engine._dummy_run() + @pytest.mark.asyncio + async def test_step_multi_request_reuses_multimodal_slice_logic(self, mocker: MockerFixture) -> None: + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.od_config = SimpleNamespace( + model_class_name="mock_model", + enable_cpu_offload=False, + ) + engine.pre_process_func = None + engine.post_process_func = None + engine._check_and_start_background_loop = mocker.AsyncMock() + engine.async_add_req_and_wait_for_response = mocker.AsyncMock( + return_value=DiffusionOutput( + output={ + "video": ["frame-0", "frame-1"], + "audio": ["audio-0", "audio-1"], + "actions": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + } + ) + ) + + request = OmniDiffusionRequest( + prompts=["prompt-0", "prompt-1"], + sampling_params=OmniDiffusionSamplingParams( + num_inference_steps=1, + num_outputs_per_prompt=1, + ), + request_ids=["req-0", "req-1"], + ) + + mocker.patch("vllm_omni.diffusion.diffusion_engine.supports_audio_output", 
return_value=False) + outputs = await engine.step(request) + + assert len(outputs) == 2 + assert outputs[0].images == ["frame-0"] + assert outputs[1].images == ["frame-1"] + assert outputs[0].multimodal_output["audio"] == "audio-0" + assert outputs[1].multimodal_output["audio"] == "audio-1" + torch.testing.assert_close( + outputs[0].multimodal_output["actions"], + torch.tensor([1.0, 2.0]), + ) + torch.testing.assert_close( + outputs[1].multimodal_output["actions"], + torch.tensor([3.0, 4.0]), + ) + class TestStepScheduler: def setup_method(self) -> None: diff --git a/tests/dreamzero/openpi_client_helper.py b/tests/dreamzero/openpi_client_helper.py new file mode 100644 index 00000000000..8918e5c36c5 --- /dev/null +++ b/tests/dreamzero/openpi_client_helper.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np + +try: + import cv2 +except ImportError: # pragma: no cover - optional e2e dependency + cv2 = None + +try: + import websockets.sync.client as websockets_client +except ImportError: # pragma: no cover - optional e2e dependency + websockets_client = None + +try: + from openpi_client import msgpack_numpy +except ImportError: # pragma: no cover - optional e2e dependency + msgpack_numpy = None + +PING_INTERVAL_SECS = 300 +PING_TIMEOUT_SECS = 3600 +DEFAULT_HOST = "127.0.0.1" +DEFAULT_PORT = 8000 +DEFAULT_PATH = "/v1/realtime/robot/openpi" +DEFAULT_PROMPT = "Move the pan forward and use the brush in the middle of the plates to brush the inside of the pan" +ACTION_HORIZON = 24 +DEFAULT_ACTION_DIM = 8 +RELATIVE_OFFSETS = [-23, -16, -8, 0] +CAMERA_FILES = { + "observation/exterior_image_0_left": "exterior_image_1_left.mp4", + "observation/exterior_image_1_left": "exterior_image_2_left.mp4", + "observation/wrist_image_left": 
"wrist_image_left.mp4", +} + + +def _decode_action_response(response: bytes | str) -> np.ndarray: + if isinstance(response, str): + raise RuntimeError(f"Inference failed: {response}") + decoded = msgpack_numpy.unpackb(response) + if isinstance(decoded, dict) and decoded.get("type") == "error": + message = decoded.get("message", decoded) + raise RuntimeError(f"Inference failed: {message}") + return np.asarray(decoded, dtype=np.float32) + + +def require_dependencies() -> None: + missing = [] + if cv2 is None: + missing.append("opencv-python") + if websockets_client is None: + missing.append("websockets") + if msgpack_numpy is None: + missing.append("openpi-client") + if missing: + raise ModuleNotFoundError(f"DreamZero OpenPI test dependencies are missing: {', '.join(missing)}") + + +@dataclass(frozen=True) +class DreamZeroServerMetadata: + image_resolution: tuple[int, int] + n_external_cameras: int + needs_wrist_camera: bool + needs_stereo_camera: bool + needs_session_id: bool + action_space: str + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> DreamZeroServerMetadata: + required_keys = ( + "image_resolution", + "n_external_cameras", + "needs_wrist_camera", + "needs_stereo_camera", + "needs_session_id", + "action_space", + ) + missing_keys = [key for key in required_keys if key not in payload] + if missing_keys: + raise ValueError(f"Missing DreamZero metadata keys: {missing_keys}") + + image_resolution = payload["image_resolution"] + if not isinstance(image_resolution, (list, tuple)) or len(image_resolution) != 2: + raise ValueError(f"Invalid image_resolution: {image_resolution!r}") + + return cls( + image_resolution=(int(image_resolution[0]), int(image_resolution[1])), + n_external_cameras=int(payload["n_external_cameras"]), + needs_wrist_camera=bool(payload["needs_wrist_camera"]), + needs_stereo_camera=bool(payload["needs_stereo_camera"]), + needs_session_id=bool(payload["needs_session_id"]), + action_space=str(payload["action_space"]), + ) + + 
+class OpenPIWebsocketClient: + def __init__( + self, + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + ) -> None: + require_dependencies() + self._uri = f"ws://{host}:{port}{path}" + self._packer = msgpack_numpy.Packer() + self._ws, self._server_metadata = self._connect() + + def _connect(self): + conn = websockets_client.connect( + self._uri, + compression=None, + max_size=None, + ping_interval=PING_INTERVAL_SECS, + ping_timeout=PING_TIMEOUT_SECS, + ) + metadata = msgpack_numpy.unpackb(conn.recv()) + if not isinstance(metadata, dict): + raise TypeError(f"Expected dict metadata from server, got {type(metadata)!r}") + return conn, metadata + + def get_server_metadata(self) -> dict[str, Any]: + return dict(self._server_metadata) + + def infer(self, obs: dict[str, Any]) -> np.ndarray: + payload = dict(obs) + payload["endpoint"] = "infer" + self._ws.send(self._packer.pack(payload)) + response = self._ws.recv() + return _decode_action_response(response) + + def reset(self, reset_info: dict[str, Any] | None = None) -> str: + payload = dict(reset_info or {}) + payload["endpoint"] = "reset" + self._ws.send(self._packer.pack(payload)) + response = self._ws.recv() + if isinstance(response, str): + return response + decoded = msgpack_numpy.unpackb(response) + if not isinstance(decoded, dict) or decoded.get("status") != "reset successful": + raise RuntimeError(f"Unexpected reset response: {decoded!r}") + return str(decoded["status"]) + + def close(self) -> None: + self._ws.close() + + +def load_all_frames(video_path: Path) -> np.ndarray: + require_dependencies() + cap = cv2.VideoCapture(str(video_path)) + frames = [] + while True: + ok, frame = cap.read() + if not ok: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + if not frames: + raise RuntimeError(f"No frames loaded from {video_path}") + return np.stack(frames, axis=0) + + +def load_camera_frames(video_dir: Path) -> dict[str, np.ndarray]: + 
camera_frames: dict[str, np.ndarray] = {} + for camera_key, file_name in CAMERA_FILES.items(): + video_path = video_dir / file_name + if not video_path.exists(): + raise FileNotFoundError(f"Missing DreamZero test asset: {video_path}") + camera_frames[camera_key] = load_all_frames(video_path) + return camera_frames + + +def build_frame_schedule(total_frames: int, num_chunks: int) -> list[list[int]]: + chunks: list[list[int]] = [] + current_frame = 23 + for _ in range(num_chunks): + indices = [max(current_frame + offset, 0) for offset in RELATIVE_OFFSETS] + if indices[-1] >= total_frames: + break + chunks.append(indices) + current_frame += ACTION_HORIZON + return chunks + + +def make_obs_from_video( + camera_frames: dict[str, np.ndarray], + frame_indices: list[int], + *, + prompt: str, + session_id: str, +) -> dict[str, Any]: + obs: dict[str, Any] = {} + for camera_key, all_frames in camera_frames.items(): + selected = all_frames[frame_indices] + obs[camera_key] = selected[0] if len(frame_indices) == 1 else selected + + obs["observation/joint_position"] = np.zeros(7, dtype=np.float32) + obs["observation/cartesian_position"] = np.zeros(6, dtype=np.float32) + obs["observation/gripper_position"] = np.zeros(1, dtype=np.float32) + obs["prompt"] = prompt + obs["session_id"] = session_id + return obs + + +def build_demo_observations( + camera_frames: dict[str, np.ndarray], + *, + prompt: str, + session_id: str, + num_chunks: int = 2, +) -> list[dict[str, Any]]: + if num_chunks < 1: + raise ValueError("num_chunks must be at least 1") + + total_frames = min(frames.shape[0] for frames in camera_frames.values()) + observations = [ + make_obs_from_video( + camera_frames, + [0], + prompt=prompt, + session_id=session_id, + ) + ] + for indices in build_frame_schedule(total_frames, num_chunks - 1): + observations.append( + make_obs_from_video( + camera_frames, + indices, + prompt=prompt, + session_id=session_id, + ) + ) + return observations + + +def validate_session_result( + 
result: dict[str, Any], + *, + expected_action_horizon: int = ACTION_HORIZON, + expected_action_dim: int = DEFAULT_ACTION_DIM, +) -> None: + metadata = DreamZeroServerMetadata.from_dict(result["metadata"]) + if metadata.image_resolution != (180, 320): + raise AssertionError(f"Unexpected image_resolution: {metadata.image_resolution}") + if metadata.n_external_cameras != 2: + raise AssertionError(f"Unexpected n_external_cameras: {metadata.n_external_cameras}") + if not metadata.needs_wrist_camera: + raise AssertionError("DreamZero test expects wrist camera metadata") + if metadata.action_space != "joint_position": + raise AssertionError(f"Unexpected action_space: {metadata.action_space}") + + actions = result["actions"] + if len(actions) != 3: + raise AssertionError(f"Expected 3 action tensors, got {len(actions)}") + for index, action in enumerate(actions): + if action.shape != (expected_action_horizon, expected_action_dim): + raise AssertionError( + f"Action {index} shape mismatch: expected " + f"{(expected_action_horizon, expected_action_dim)}, got {action.shape}" + ) + if not np.isfinite(action).all(): + raise AssertionError(f"Action {index} contains non-finite values") + + if result["reset_status"] != "reset successful": + raise AssertionError(f"Unexpected reset status: {result['reset_status']!r}") + + +def run_policy_session( + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + video_dir: Path, + prompt: str = DEFAULT_PROMPT, + session_id: str | None = None, + num_chunks: int = 2, +) -> dict[str, Any]: + session_id = session_id or str(uuid.uuid4()) + camera_frames = load_camera_frames(video_dir) + observations = build_demo_observations( + camera_frames, + prompt=prompt, + session_id=session_id, + num_chunks=num_chunks, + ) + + client = OpenPIWebsocketClient(host=host, port=port, path=path) + try: + metadata = client.get_server_metadata() + actions = [client.infer(obs) for obs in observations] + reset_status = 
client.reset({}) + actions.append(client.infer(observations[0])) + return { + "metadata": metadata, + "actions": actions, + "reset_status": reset_status, + "session_id": session_id, + } + finally: + client.close() diff --git a/tests/dreamzero/test_openpi_client_helper.py b/tests/dreamzero/test_openpi_client_helper.py new file mode 100644 index 00000000000..92d64061570 --- /dev/null +++ b/tests/dreamzero/test_openpi_client_helper.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle + +import numpy as np +import pytest + +from tests.dreamzero import openpi_client_helper + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class FakeMsgpackNumpy: + @staticmethod + def packb(obj): + return pickle.dumps(obj) + + @staticmethod + def unpackb(data): + return pickle.loads(data) + + +def test_decode_action_response_surfaces_structured_error(monkeypatch): + monkeypatch.setattr(openpi_client_helper, "msgpack_numpy", FakeMsgpackNumpy) + payload = FakeMsgpackNumpy.packb( + { + "type": "error", + "message": "Internal inference error", + } + ) + + with pytest.raises(RuntimeError, match="Internal inference error"): + openpi_client_helper._decode_action_response(payload) + + +def test_decode_action_response_converts_action_payload_to_float32(monkeypatch): + monkeypatch.setattr(openpi_client_helper, "msgpack_numpy", FakeMsgpackNumpy) + payload = FakeMsgpackNumpy.packb(np.asarray([[1.0, 2.0]], dtype=np.float64)) + + actions = openpi_client_helper._decode_action_response(payload) + + assert actions.dtype == np.float32 + np.testing.assert_array_equal(actions, np.asarray([[1.0, 2.0]], dtype=np.float32)) diff --git a/tests/dreamzero/test_pipeline_state.py b/tests/dreamzero/test_pipeline_state.py new file mode 100644 index 00000000000..a3a5aaec03c --- /dev/null +++ b/tests/dreamzero/test_pipeline_state.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project + +from collections import OrderedDict + +import pytest +import torch + +from vllm_omni.diffusion.models.dreamzero.pipeline_dreamzero import DreamZeroPipeline +from vllm_omni.diffusion.models.dreamzero.state_dreamzero import DreamZeroState + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _empty_pipeline() -> DreamZeroPipeline: + pipeline = DreamZeroPipeline.__new__(DreamZeroPipeline) + pipeline._states = OrderedDict() + pipeline._max_session_states = 2 + return pipeline + + +def test_dreamzero_pipeline_state_is_session_keyed() -> None: + pipeline = _empty_pipeline() + + session_a = pipeline._get_or_create_state("session-a") + session_b = pipeline._get_or_create_state("session-b") + session_a.call_count = 7 + session_b.call_count = 3 + + assert pipeline._get_or_create_state("session-a") is session_a + assert pipeline._get_or_create_state("session-b") is session_b + assert session_a.call_count == 7 + assert session_b.call_count == 3 + + +def test_dreamzero_pipeline_state_lru_caps_retained_sessions() -> None: + pipeline = _empty_pipeline() + + session_a = pipeline._get_or_create_state("session-a") + pipeline._get_or_create_state("session-b") + assert pipeline._get_or_create_state("session-a") is session_a + + pipeline._get_or_create_state("session-c") + + assert list(pipeline._states) == ["session-a", "session-c"] + assert "session-b" not in pipeline._states + + +def test_dreamzero_state_cache_access_requires_initialization() -> None: + state = DreamZeroState() + + with pytest.raises(RuntimeError, match="KV caches not initialized"): + state.get_kv_caches() + + with pytest.raises(RuntimeError, match="Cross-attn caches not initialized"): + state.get_crossattn_caches() + + with pytest.raises(RuntimeError, match="create_kv_caches first"): + state.update_kv_cache(0, torch.empty(0)) diff --git a/tests/dreamzero/test_utils.py b/tests/dreamzero/test_utils.py new file mode 100644 index 00000000000..c399bd70b1b --- 
/dev/null +++ b/tests/dreamzero/test_utils.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm_omni.diffusion.models.dreamzero.utils import ( + DEFAULT_CFG_SCALE, + DEFAULT_EMBODIMENT_NAME_TO_ID, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_NUM_INFERENCE_STEPS, + DEFAULT_SEED, + DEFAULT_SIGMA_SHIFT, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_dreamzero_default_constants_match_source_baseline(): + assert DEFAULT_NUM_INFERENCE_STEPS == 16 + assert DEFAULT_CFG_SCALE == 5.0 + assert DEFAULT_SIGMA_SHIFT == 5.0 + assert DEFAULT_SEED == 1140 + assert "worst quality" in DEFAULT_NEGATIVE_PROMPT + assert DEFAULT_EMBODIMENT_NAME_TO_ID["oxe_droid"] == 17 diff --git a/tests/dreamzero/upstream/openpi_test_client_ar.py b/tests/dreamzero/upstream/openpi_test_client_ar.py new file mode 100644 index 00000000000..7c5aea1d901 --- /dev/null +++ b/tests/dreamzero/upstream/openpi_test_client_ar.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +""" +Copied from the DreamZero repository's `test_client_AR.py`: +https://github.com/dreamzero0/dreamzero/blob/main/test_client_AR.py + +Kept here for end-to-end compatibility / parity testing against the vLLM +OpenPI server. + +Test client for AR_droid policy server using roboarena interface. + +Sends real video frames from debug_image/ directory instead of zero dummy images. + +Frame schedule (matching debug_inference.py): + - Step 0 (initial): send frame [0] (1 frame, H W 3) + - Step 1: send frames [0, 7, 15, 23] (4 frames, 4 H W 3) + - Step 2: send frames [24, 31, 39, 47] (4 frames) + - Step 3: send frames [48, 55, 63, 71] (4 frames) + - ... 
+ +Expected server configuration: + - image_resolution: (180, 320) + - n_external_cameras: 2 + - needs_wrist_camera: True + - action_space: "joint_position" + +Usage: + # Start server with roboarena interface: + torchrun --nproc_per_node=8 socket_test_optimized_AR.py --port 8000 + + # Run against the original DreamZero websocket server (default path): + python test_client_AR.py --host --port 8000 + + # Run against the vLLM OpenPI server: + python test_client_AR.py --host --port 8000 --path /v1/realtime/robot/openpi + + # Use zero images instead of real video (old behavior): + python test_client_AR.py --host --port 8000 --use-zero-images +""" + +import argparse +import logging +import os +import sys +import time +import uuid +from pathlib import Path + +import cv2 +import numpy as np +from openpi_client import msgpack_numpy + +_DREAMZERO_REPO_ENV = os.environ.get("DREAMZERO_REPO") +DREAMZERO_REPO = Path(_DREAMZERO_REPO_ENV).expanduser() if _DREAMZERO_REPO_ENV else None + + +def _import_upstream_policy_modules(): + if DREAMZERO_REPO is None: + raise ImportError("Set DREAMZERO_REPO to an upstream DreamZero checkout before using this helper.") + if DREAMZERO_REPO.exists() and str(DREAMZERO_REPO) not in sys.path: + sys.path.insert(0, str(DREAMZERO_REPO)) + + import eval_utils.policy_server as policy_server + from eval_utils.policy_client import WebsocketClientPolicy + + return policy_server, WebsocketClientPolicy + + +policy_server, WebsocketClientPolicy = _import_upstream_policy_modules() + +VIDEO_DIR = os.environ.get( + "DREAMZERO_VIDEO_DIR", + str(DREAMZERO_REPO / "debug_image") if DREAMZERO_REPO is not None else "debug_image", +) + +# roboarena key -> video filename +CAMERA_FILES = { + "observation/exterior_image_0_left": "exterior_image_1_left.mp4", + "observation/exterior_image_1_left": "exterior_image_2_left.mp4", + "observation/wrist_image_left": "wrist_image_left.mp4", +} + +# Frame schedule constants (matching debug_inference.py) +RELATIVE_OFFSETS = [-23, -16, 
-8, 0] +ACTION_HORIZON = 24 +DEFAULT_WEBSOCKET_PATH = "" + + +class OpenPIWebsocketClientPolicy(WebsocketClientPolicy): + """DreamZero websocket client with a configurable path suffix. + + The original DreamZero client connects to ``ws://host:port``. + vLLM serves the compatible robot policy endpoint at + ``/v1/realtime/robot/openpi`` when ``path`` is set accordingly. + """ + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8000, + path: str = DEFAULT_WEBSOCKET_PATH, + ) -> None: + self._uri = f"ws://{host}:{port}{path}" + self._packer = msgpack_numpy.Packer() + self._ws, self._server_metadata = self._wait_for_server() + + +def load_all_frames(video_path: str) -> np.ndarray: + """Load all frames from a video file. Returns (N, H, W, 3) uint8 array (RGB).""" + cap = cv2.VideoCapture(video_path) + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + if not frames: + raise RuntimeError(f"No frames loaded from {video_path}") + return np.stack(frames, axis=0) + + +def load_camera_frames() -> dict[str, np.ndarray]: + """Load all video frames for each camera from the debug_image/ directory. + + Returns: + Dict mapping roboarena camera keys to (N, H, W, 3) uint8 arrays. + """ + camera_frames: dict[str, np.ndarray] = {} + for cam_key, fname in CAMERA_FILES.items(): + path = os.path.join(VIDEO_DIR, fname) + camera_frames[cam_key] = load_all_frames(path) + logging.info(f"Loaded {cam_key}: {camera_frames[cam_key].shape}") + return camera_frames + + +def build_frame_schedule(total_frames: int, num_chunks: int) -> list[list[int]]: + """Build the frame index schedule for multi-frame chunks. + + Returns a list of frame-index lists. Each inner list has 4 indices. 
+ """ + chunks: list[list[int]] = [] + current_frame = 23 # first anchor frame + for _ in range(num_chunks): + indices = [max(current_frame + off, 0) for off in RELATIVE_OFFSETS] + if indices[-1] >= total_frames: + logging.info(f"Frame {indices[-1]} >= {total_frames}, stopping at {len(chunks)} chunks") + break + chunks.append(indices) + current_frame += ACTION_HORIZON + return chunks + + +def _make_obs_from_video( + camera_frames: dict[str, np.ndarray], + frame_indices: list[int], + prompt: str, + session_id: str, +) -> dict: + """Build an observation dict from real video frames. + + For 1 frame: each image key is (H, W, 3). + For 4 frames: each image key is (4, H, W, 3). + """ + obs: dict = {} + for cam_key, all_frames in camera_frames.items(): + selected = all_frames[frame_indices] # (T, H, W, 3) + if len(frame_indices) == 1: + selected = selected[0] # (H, W, 3) + obs[cam_key] = selected + + obs["observation/joint_position"] = np.zeros(7, dtype=np.float32) + obs["observation/cartesian_position"] = np.zeros(6, dtype=np.float32) + obs["observation/gripper_position"] = np.zeros(1, dtype=np.float32) + obs["prompt"] = prompt + obs["session_id"] = session_id + return obs + + +def _make_zero_observation( + server_config: policy_server.PolicyServerConfig, + prompt: str = "pick up the object", + session_id: str | None = None, +) -> dict: + """Create a dummy observation matching AR_droid expectations. 
+ + AR_droid expects: + - 2 external cameras (exterior_image_0_left, exterior_image_1_left) + - 1 wrist camera (wrist_image_left) + - Image resolution: 180x320 (H x W) + - joint_position: 7 DoF + - gripper_position: 1 DoF + """ + obs = {} + + # Determine image resolution + if server_config.image_resolution is not None: + h, w = server_config.image_resolution + else: + # Default for AR_droid + h, w = 180, 320 + + # External cameras (0-indexed in roboarena) + for i in range(server_config.n_external_cameras): + obs[f"observation/exterior_image_{i}_left"] = np.zeros((h, w, 3), dtype=np.uint8) + if server_config.needs_stereo_camera: + obs[f"observation/exterior_image_{i}_right"] = np.zeros((h, w, 3), dtype=np.uint8) + + # Wrist camera + if server_config.needs_wrist_camera: + obs["observation/wrist_image_left"] = np.zeros((h, w, 3), dtype=np.uint8) + if server_config.needs_stereo_camera: + obs["observation/wrist_image_right"] = np.zeros((h, w, 3), dtype=np.uint8) + + # Session ID - should be passed in to ensure consistency within a session + if server_config.needs_session_id: + import uuid + + # Generate unique session ID if not provided + obs["session_id"] = session_id if session_id else str(uuid.uuid4()) + + # State observations (AR_droid: 7 DoF arm + 1 gripper) + obs["observation/joint_position"] = np.zeros(7, dtype=np.float32) + obs["observation/cartesian_position"] = np.zeros(6, dtype=np.float32) + obs["observation/gripper_position"] = np.zeros(1, dtype=np.float32) + + # Language prompt + obs["prompt"] = prompt + + return obs + + +def test_ar_droid_policy_server( + host: str = "localhost", + port: int = 8000, + path: str = DEFAULT_WEBSOCKET_PATH, + num_chunks: int = 15, + prompt: str = "Move the pan forward and use the brush in the middle of the plates to brush the inside of the pan", + use_zero_images: bool = False, +): + """Test the AR_droid policy server with roboarena interface. 
+ + When use_zero_images is False (default), loads real video frames from + debug_image/ and follows the frame schedule from debug_inference.py. + """ + logging.info(f"Connecting to AR_droid server at ws://{host}:{port}{path} ...") + + client = OpenPIWebsocketClientPolicy(host=host, port=port, path=path) + + # Validate server metadata + metadata = client.get_server_metadata() + logging.info(f"Server metadata: {metadata}") + assert isinstance(metadata, dict), "Metadata should be a dict" + + try: + server_config = policy_server.PolicyServerConfig(**metadata) + except Exception as e: + logging.error(f"Error parsing metadata: {e}") + raise e + + # Validate expected AR_droid configuration + logging.info(f"Server config: {server_config}") + assert server_config.n_external_cameras == 2, f"Expected 2 external cameras, got {server_config.n_external_cameras}" + assert server_config.needs_wrist_camera, "Expected wrist camera to be enabled" + assert server_config.action_space == "joint_position", ( + f"Expected joint_position action space, got {server_config.action_space}" + ) + + logging.info("Server configuration validated for AR_droid") + + # Generate unique session ID for this test run + session_id = str(uuid.uuid4()) + logging.info(f"Session ID: {session_id}") + + # ── Zero-image fallback mode ────────────────────────────────────── + if use_zero_images: + logging.info("Using ZERO dummy images (legacy mode)") + for i in range(num_chunks): + obs = _make_zero_observation(server_config, prompt=prompt, session_id=session_id) + logging.info(f"Inference {i + 1}/{num_chunks}: prompt='{prompt}'") + t0 = time.time() + actions = client.infer(obs) + dt = time.time() - t0 + _log_action(actions, dt) + + logging.info("Sending reset...") + client.reset({}) + logging.info("Done (zero-image mode).") + return + + # ── Real video frame mode ───────────────────────────────────────── + logging.info("Loading real video frames from debug_image/ directory") + camera_frames = load_camera_frames() 
+ + total_frames = min(v.shape[0] for v in camera_frames.values()) + logging.info(f"Total frames available: {total_frames}") + + # Build frame schedule + chunks = build_frame_schedule(total_frames, num_chunks) + + logging.info("Frame schedule:") + logging.info(" Initial: [0]") + for i, indices in enumerate(chunks): + logging.info(f" Chunk {i}: {indices}") + + # Step 0: initial single frame + logging.info("=== Initial: frame [0] ===") + obs = _make_obs_from_video(camera_frames, [0], prompt, session_id) + t0 = time.time() + actions = client.infer(obs) + dt = time.time() - t0 + _log_action(actions, dt) + + # Subsequent chunks: send 4 frames at a time + for chunk_idx, frame_indices in enumerate(chunks): + logging.info(f"=== Chunk {chunk_idx}: frames {frame_indices} ===") + obs = _make_obs_from_video(camera_frames, frame_indices, prompt, session_id) + t0 = time.time() + actions = client.infer(obs) + dt = time.time() - t0 + _log_action(actions, dt) + + # Reset triggers video save on the server + logging.info("Sending reset to save video...") + client.reset({}) + + logging.info("Done.") + + +def _log_action(actions: np.ndarray, dt: float) -> None: + """Pretty-print action shape, range, and timing.""" + assert isinstance(actions, np.ndarray), f"Expected numpy array, got {type(actions)}" + assert actions.ndim == 2, f"Expected 2D array, got shape {actions.shape}" + assert actions.shape[-1] == 8, f"Expected 8 action dims (7 joints + 1 gripper), got {actions.shape[-1]}" + logging.info(f" Action shape: {actions.shape}, range: [{actions.min():.4f}, {actions.max():.4f}], time: {dt:.2f}s") + + +def main(): + parser = argparse.ArgumentParser(description="Test AR_droid policy server with real video frames from debug_image/") + parser.add_argument("--host", default="localhost", help="Server hostname") + parser.add_argument("--port", type=int, default=8000, help="Server port") + parser.add_argument( + "--path", + default=DEFAULT_WEBSOCKET_PATH, + help="WebSocket path suffix (default: 
empty string for the original DreamZero server)", + ) + parser.add_argument( + "--num-chunks", + type=int, + default=15, + help="Number of 4-frame chunks to send after the initial frame (default: 15)", + ) + parser.add_argument( + "--prompt", + default="Move the pan forward and use the brush in the middle of the plates to brush the inside of the pan", + help="Language prompt for the policy", + ) + parser.add_argument( + "--use-zero-images", + action="store_true", + help="Use zero dummy images instead of real video frames (legacy mode)", + ) + + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + ) + + test_ar_droid_policy_server( + host=args.host, + port=args.port, + path=args.path, + num_chunks=args.num_chunks, + prompt=args.prompt, + use_zero_images=args.use_zero_images, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/dreamzero/upstream/test_openpi_e2e_source_parity.py b/tests/dreamzero/upstream/test_openpi_e2e_source_parity.py new file mode 100644 index 00000000000..e41cff9032c --- /dev/null +++ b/tests/dreamzero/upstream/test_openpi_e2e_source_parity.py @@ -0,0 +1,358 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Formal OpenPI end-to-end parity: upstream DreamZero server vs `vllm serve`. + +This test uses DreamZero's own client-side observation builders from +`${DREAMZERO_REPO}/test_client_AR.py`, and client-side websocket protocol from +`${DREAMZERO_REPO}/eval_utils/policy_client.py`. + +The only client-side adaptation for vLLM is the websocket path: +DreamZero's upstream server serves at `/`, while vLLM serves OpenPI at +`/v1/realtime/robot/openpi`. 
+ +Current scope for this test: +- default two-GPU run (`nproc_per_node=2` on upstream, `--cfg-parallel-size 2` on `vllm serve`) +- non-`torch.compile` (upstream launched through + `upstream_socket_server_no_compile.py`, vLLM with `--enforce-eager`) +- non-DiT-cache / non-skip-schedule (`NUM_DIT_STEPS=16`) + +Serving contract locked by this test: +- upstream DreamZero still boots from the local checkpoint directory +- vLLM boots from the official DreamZero HF repo name (`GEAR-Dreams/DreamZero-DROID`) + rather than a prepared local bundle path +""" + +from __future__ import annotations + +import os +import shutil +import subprocess +import sys +import time +from pathlib import Path + +import numpy as np +import pytest +import torch + +from tests.helpers.runtime import get_open_port + +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") + +_DREAMZERO_REPO_ENV = os.environ.get("DREAMZERO_REPO") +DREAMZERO_REPO = Path(_DREAMZERO_REPO_ENV).expanduser() if _DREAMZERO_REPO_ENV else None +if DREAMZERO_REPO is not None and str(DREAMZERO_REPO) not in sys.path: + sys.path.insert(0, str(DREAMZERO_REPO)) + +try: + import test_client_AR as dreamzero_client + from eval_utils.policy_client import WebsocketClientPolicy +except Exception: # pragma: no cover - guarded by pytest skip below + dreamzero_client = None + WebsocketClientPolicy = None +_BaseWebsocketClientPolicy = WebsocketClientPolicy if WebsocketClientPolicy is not None else object + +CHECKPOINT_DIR = DREAMZERO_REPO / "checkpoints" / "dreamzero" if DREAMZERO_REPO is not None else None +VLLM_MODEL = os.environ.get("VLLM_DREAMZERO_MODEL", "GEAR-Dreams/DreamZero-DROID") +SERVICE_READY_TIMEOUT_S = int(os.environ.get("OPENPI_SERVICE_READY_TIMEOUT_S", "900")) +PROMPT = "Move the pan forward and use the brush in the middle of the plates to brush the inside of the pan" +SESSION_ID = "openpi-e2e-parity-session" + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required"), + 
pytest.mark.skipif( + dreamzero_client is None or WebsocketClientPolicy is None, + reason="DreamZero client modules are required on PYTHONPATH", + ), + pytest.mark.skipif( + DREAMZERO_REPO is None or not DREAMZERO_REPO.exists(), + reason="DreamZero source repo is required at DREAMZERO_REPO", + ), + pytest.mark.skipif( + CHECKPOINT_DIR is None or not CHECKPOINT_DIR.exists(), reason="DreamZero local checkpoint is required" + ), +] + + +class OpenPIWebsocketClientPolicy(_BaseWebsocketClientPolicy): + """DreamZero client protocol with an OpenPI websocket path suffix.""" + + def __init__( + self, + host: str = "127.0.0.1", + port: int = 8000, + path: str = "/v1/realtime/robot/openpi", + ) -> None: + self._uri = f"ws://{host}:{port}{path}" + self._packer = msgpack_numpy.Packer() + self._ws, self._server_metadata = self._wait_for_server() + + +def _vllm_executable() -> str: + fallback = Path(sys.executable).with_name("vllm") + if fallback.exists(): + return str(fallback) + exe = shutil.which("vllm") + if exe: + return exe + raise FileNotFoundError("Unable to locate `vllm` executable in current environment.") + + +def _cfg_parallel_size() -> int: + return int(os.environ.get("OPENPI_E2E_CFG_PARALLEL_SIZE", "2")) + + +def _pick_test_gpus() -> list[str]: + cfg_parallel_size = _cfg_parallel_size() + override = os.environ.get("OPENPI_E2E_GPUS") or os.environ.get("OPENPI_E2E_GPU") + if override is not None: + gpus = [part.strip() for part in override.split(",") if part.strip()] + if not gpus: + raise ValueError("OPENPI_E2E_GPUS is set but empty.") + if len(gpus) < cfg_parallel_size: + raise RuntimeError(f"Need {cfg_parallel_size} GPUs, but OPENPI_E2E_GPUS only provided {gpus}.") + return gpus + + query = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=index,memory.used", + "--format=csv,noheader,nounits", + ], + text=True, + ) + gpu_rows = [] + for line in query.strip().splitlines(): + gpu_index, used_mb = [part.strip() for part in line.split(",", maxsplit=1)] + 
gpu_rows.append((int(used_mb), gpu_index)) + gpu_rows.sort() + gpus = [gpu_index for _, gpu_index in gpu_rows[: max(cfg_parallel_size, 1)]] + if len(gpus) < cfg_parallel_size: + raise RuntimeError( + f"Need {cfg_parallel_size} GPUs for cfg_parallel_size={cfg_parallel_size}, " + f"but found only {len(gpus)} candidates." + ) + return gpus + + +def _torchrun_argv(script: str, port: int) -> list[str]: + return [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(_cfg_parallel_size()), + script, + "--port", + str(port), + "--model_path", + str(CHECKPOINT_DIR), + ] + + +def _run_upstream_service(port: int, log_path: Path) -> subprocess.Popen[str]: + env = os.environ.copy() + env.setdefault("PYTHONPATH", "") + env["PYTHONPATH"] = f"{Path.cwd()}:{DREAMZERO_REPO}:{env['PYTHONPATH']}".rstrip(":") + env["CUDA_VISIBLE_DEVICES"] = ",".join(_pick_test_gpus()) + env.setdefault("NO_ALBUMENTATIONS_UPDATE", "1") + env["ATTENTION_BACKEND"] = "torch" + env.setdefault("ENABLE_TENSORRT", "false") + env["ENABLE_DIT_CACHE"] = "false" + env["NUM_DIT_STEPS"] = "16" + env["DYNAMIC_CACHE_SCHEDULE"] = "false" + argv = _torchrun_argv( + str(Path("tests/dreamzero/upstream/upstream_socket_server_no_compile.py")), + port, + ) + log_file = log_path.open("w") + proc = subprocess.Popen( + argv, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + cwd=str(Path.cwd()), + env=env, + ) + proc._codex_log_file = log_file # type: ignore[attr-defined] + return proc + + +def _run_vllm_service(port: int, log_path: Path) -> subprocess.Popen[str]: + env = os.environ.copy() + gpus = _pick_test_gpus() + cfg_parallel_size = _cfg_parallel_size() + if cfg_parallel_size > len(gpus): + raise RuntimeError( + f"cfg_parallel_size={cfg_parallel_size} requires at least {cfg_parallel_size} GPUs, but only got {gpus}." 
+ ) + env["CUDA_VISIBLE_DEVICES"] = ",".join(gpus[:cfg_parallel_size]) + env.setdefault("ATTENTION_BACKEND", "torch") + env.setdefault("DIFFUSION_ATTENTION_BACKEND", "TORCH_SDPA") + env.setdefault("MASTER_PORT", str(get_open_port())) + argv = [ + _vllm_executable(), + "serve", + VLLM_MODEL, + "--omni", + "--host", + "127.0.0.1", + "--port", + str(port), + "--served-model-name", + "dreamzero-droid", + "--enforce-eager", + ] + if cfg_parallel_size > 1: + argv.extend(["--cfg-parallel-size", str(cfg_parallel_size)]) + log_file = log_path.open("w") + proc = subprocess.Popen( + argv, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + env=env, + cwd=str(Path.cwd()), + ) + proc._codex_log_file = log_file # type: ignore[attr-defined] + return proc + + +def _stop_process(proc: subprocess.Popen[str]) -> None: + log_file = getattr(proc, "_codex_log_file", None) + if proc.poll() is None: + proc.terminate() + try: + proc.wait(timeout=30) + except subprocess.TimeoutExpired: # pragma: no cover - cleanup path + proc.kill() + proc.wait(timeout=10) + if log_file is not None: + log_file.close() + + +def _build_obs_sequence() -> tuple[dict, dict]: + camera_frames = dreamzero_client.load_camera_frames() + chunks = dreamzero_client.build_frame_schedule( + min(v.shape[0] for v in camera_frames.values()), + 1, + ) + obs0 = dreamzero_client._make_obs_from_video(camera_frames, [0], PROMPT, SESSION_ID) + obs1 = dreamzero_client._make_obs_from_video(camera_frames, chunks[0], PROMPT, SESSION_ID) + return obs0, obs1 + + +def _wait_for_client_ready(client_factory, timeout_s: float, proc=None, log_path: Path | None = None): + deadline = time.time() + timeout_s + last_err: Exception | None = None + while time.time() < deadline: + if proc is not None and proc.poll() is not None: + details = "" + if log_path is not None and log_path.exists(): + details = log_path.read_text(errors="replace")[-8000:] + raise RuntimeError(f"Service exited before becoming ready with code 
{proc.returncode}.\n{details}") + try: + return client_factory() + except Exception as exc: # pragma: no cover - retry path + last_err = exc + time.sleep(1) + raise TimeoutError(f"Timed out waiting for websocket service: {last_err}") + + +def _collect_outputs_with_client(client) -> tuple[dict, list[np.ndarray]]: + metadata = client.get_server_metadata() + obs0, obs1 = _build_obs_sequence() + outputs = [ + client.infer(dict(obs0)), + client.infer(dict(obs1)), + ] + assert _normalize_reset_response(client.reset({})) == "reset successful" + outputs.append(client.infer(dict(obs0))) + client._ws.close() + return metadata, outputs + + +def _normalize_reset_response(response) -> str: + if isinstance(response, str): + return response + decoded = msgpack_numpy.unpackb(response) + if isinstance(decoded, dict): + return str(decoded.get("status")) + return str(decoded) + + +def _normalize_metadata(metadata: dict) -> dict: + normalized = dict(metadata) + if isinstance(normalized.get("image_resolution"), tuple): + normalized["image_resolution"] = list(normalized["image_resolution"]) + return normalized + + +def _assert_logs_clean(log_path: Path) -> None: + text = log_path.read_text(errors="replace") + if "SignalException: Process" in text and "got signal: 15" in text: + text = text.split("Traceback (most recent call last):", 1)[0] + assert "Traceback" not in text, text + assert "RuntimeError:" not in text, text + + +def _assert_upstream_log_matches_vllm_baseline(log_path: Path) -> None: + text = log_path.read_text(errors="replace") + assert "DIT Compute Steps 8 steps" not in text, text + assert "DIT Compute Steps 16 steps" in text, text + + +def test_openpi_service_matches_upstream_server_noncompile(tmp_path: Path) -> None: + expected_metadata = { + "image_resolution": [180, 320], + "n_external_cameras": 2, + "needs_wrist_camera": True, + "needs_stereo_camera": False, + "needs_session_id": True, + "action_space": "joint_position", + } + + upstream_port = get_open_port() + 
upstream_log = tmp_path / "dreamzero_upstream.log" + upstream_proc = _run_upstream_service(upstream_port, upstream_log) + try: + upstream_client = _wait_for_client_ready( + lambda: WebsocketClientPolicy(host="127.0.0.1", port=upstream_port), + timeout_s=SERVICE_READY_TIMEOUT_S, + proc=upstream_proc, + log_path=upstream_log, + ) + upstream_metadata, upstream_outputs = _collect_outputs_with_client(upstream_client) + finally: + _stop_process(upstream_proc) + _assert_logs_clean(upstream_log) + _assert_upstream_log_matches_vllm_baseline(upstream_log) + + vllm_port = get_open_port() + vllm_log = tmp_path / "vllm_openpi.log" + vllm_proc = _run_vllm_service(vllm_port, vllm_log) + try: + vllm_client = _wait_for_client_ready( + lambda: OpenPIWebsocketClientPolicy(host="127.0.0.1", port=vllm_port), + timeout_s=SERVICE_READY_TIMEOUT_S, + proc=vllm_proc, + log_path=vllm_log, + ) + vllm_metadata, vllm_outputs = _collect_outputs_with_client(vllm_client) + finally: + _stop_process(vllm_proc) + _assert_logs_clean(vllm_log) + + assert _normalize_metadata(upstream_metadata) == expected_metadata + assert _normalize_metadata(vllm_metadata) == expected_metadata + + for idx, (actual, expected) in enumerate(zip(vllm_outputs, upstream_outputs, strict=True)): + np.testing.assert_allclose( + actual, + expected, + rtol=1e-2, + atol=1e-3, + err_msg=f"OpenPI step {idx} output mismatch", + ) diff --git a/tests/dreamzero/upstream/upstream_socket_server_no_compile.py b/tests/dreamzero/upstream/upstream_socket_server_no_compile.py new file mode 100644 index 00000000000..a9ef7a3d00d --- /dev/null +++ b/tests/dreamzero/upstream/upstream_socket_server_no_compile.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Launch the upstream DreamZero websocket server with `torch.compile` disabled. + +This wrapper is meant for formal parity tests against `vllm serve --omni`. 
def _identity_compile(*args, **kwargs):
    """No-op replacement for ``torch.compile`` that keeps everything eager.

    Mirrors both call forms of ``torch.compile``:

    * direct call, ``torch.compile(fn)`` or ``torch.compile(fn, mode=...)``
      or ``torch.compile(model=fn, ...)``: returns ``fn`` unchanged;
    * decorator factory, ``torch.compile(mode=...)`` or ``torch.compile()``:
      returns a decorator that returns its argument unchanged.

    Compile options passed as keyword arguments are ignored.
    """
    # The target may be given positionally or as the ``model`` keyword. Extra
    # options must not turn a direct call into the decorator form, otherwise
    # the caller would receive ``deco`` in place of its module or function.
    target = args[0] if args else kwargs.get("model")
    if target is not None:
        return target

    def deco(fn):
        return fn

    return deco
def _torch_varlen_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_lens: torch.Tensor | None = None,
    k_lens: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    causal: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    **_: object,
) -> torch.Tensor:
    """Test-only ``flash_attention`` stand-in built on PyTorch SDPA.

    Each of ``q``, ``k`` and ``v`` has dims 1 and 2 swapped and is cast to
    ``dtype`` before ``F.scaled_dot_product_attention`` is called; the result
    is swapped back, made contiguous and cast to the dtype ``q`` arrived in.
    NOTE(review): the swap presumably maps a (batch, seq, heads, dim) layout
    onto SDPA's (batch, heads, seq, dim) — confirm against upstream.

    ``q_lens`` / ``k_lens`` are not applied: no attention mask is built, and
    a warning is emitted when either is given. Any further keyword arguments
    are accepted and ignored.
    """
    lengths_given = not (q_lens is None and k_lens is None)
    if lengths_given:
        upstream_attention.warnings.warn(
            "Padding mask is disabled in the test-only SDPA fallback.",
        )

    result_dtype = q.dtype
    query, key, value = (tensor.transpose(1, 2).to(dtype) for tensor in (q, k, v))
    attended = F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=causal,
    )
    return attended.transpose(1, 2).contiguous().to(result_dtype)
+ embodiment_tag=upstream.EmbodimentTag(embodiment_tag), + model_path=model_path, + device="cuda" if torch.cuda.is_available() else "cpu", + device_mesh=device_mesh, + ) + + hostname = upstream.socket.gethostname() + local_ip = upstream.socket.gethostbyname(hostname) + + if rank == 0: + upstream.logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip) + parent_dir = os.path.dirname(model_path) + date_suffix = upstream.datetime.datetime.now().strftime("%Y%m%d") + checkpoint_name = os.path.basename(model_path) + output_dir = os.path.join( + parent_dir, + f"real_world_eval_gen_{date_suffix}_{args.index}", + checkpoint_name, + ) + os.makedirs(output_dir, exist_ok=True) + upstream.logging.info("Videos will be saved to: %s", output_dir) + else: + output_dir = None + upstream.logging.info("Rank %s starting as worker for distributed inference...", rank) + + wrapper_policy = upstream.ARDroidRoboarenaPolicy( + groot_policy=policy, + signal_group=signal_group, + output_dir=output_dir, + ) + + server_config = upstream.PolicyServerConfig( + image_resolution=(180, 320), + needs_wrist_camera=True, + n_external_cameras=2, + needs_stereo_camera=False, + needs_session_id=True, + action_space="joint_position", + ) + + if rank == 0: + upstream.logging.info("Using roboarena policy server interface") + upstream.logging.info("Server config: %s", server_config) + roboarena_server = upstream.RoboarenaServer( + policy=wrapper_policy, + server_config=server_config, + host="0.0.0.0", + port=args.port, + ) + roboarena_server.serve_forever() + else: + worker = upstream.DistributedWorker( + policy=policy, + signal_group=signal_group, + ) + upstream.asyncio.run(worker.run()) + + +if __name__ == "__main__": + upstream.logging.basicConfig(level=upstream.logging.INFO, force=True) + main(tyro.cli(upstream.Args)) diff --git a/tests/e2e/online_serving/test_dreamzero_expansion.py b/tests/e2e/online_serving/test_dreamzero_expansion.py new file mode 100644 index 00000000000..7a2fd3ac98f --- 
def _pick_test_gpus() -> str:
    """Return the ``CUDA_VISIBLE_DEVICES`` value to use for the E2E server.

    Resolution order:

    1. ``DREAMZERO_TEST_GPUS`` or ``OPENPI_E2E_GPUS`` from the environment,
       returned verbatim when set and non-empty.
    2. The (up to) two GPUs with the least used memory as reported by
       ``nvidia-smi``; ties are broken by the lower GPU index.
    3. ``"0,1"`` when ``nvidia-smi`` cannot be run or yields no usable row.

    This runs at import time while ``test_params`` is built, so it must never
    raise: rows that are blank or whose fields are not integers (for example
    a ``[N/A]`` memory reading) are skipped instead of aborting collection.
    """
    override = os.environ.get("DREAMZERO_TEST_GPUS") or os.environ.get("OPENPI_E2E_GPUS")
    if override:
        return override

    try:
        query = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=index,memory.used",
                "--format=csv,noheader,nounits",
            ],
            text=True,
        )
    except (OSError, subprocess.SubprocessError, ValueError):
        # Missing binary, non-zero exit or undecodable output: best effort.
        return "0,1"

    gpu_rows = []
    for line in query.splitlines():
        gpu_index, _, used_mb = (part.strip() for part in line.partition(","))
        try:
            gpu_rows.append((int(used_mb), int(gpu_index)))
        except ValueError:
            continue
    gpu_rows.sort()
    return ",".join(str(gpu_index) for _, gpu_index in gpu_rows[:2]) or "0,1"
def _write_synthetic_dreamzero_videos(client_mod, video_dir: Path) -> None:
    """Write one synthetic clip per entry of ``client_mod.CAMERA_FILES``.

    Each file name in ``CAMERA_FILES.values()`` is written under
    ``video_dir``. Its position in that iteration order is passed on as the
    ``channel`` argument of ``_write_synthetic_video``, and the OpenCV
    module is taken from ``client_mod.cv2``.
    """
    for channel, file_name in enumerate(client_mod.CAMERA_FILES.values()):
        _write_synthetic_video(
            path=video_dir / file_name,
            cv2_module=client_mod.cv2,
            channel=channel,
        )
class FakeWebSocket:
    """In-memory websocket double for the connection tests.

    Incoming messages are served from a private copy of ``messages`` in
    order. Everything passed to ``send_bytes`` / ``send_text`` is recorded in
    ``sent_bytes`` / ``sent_texts``, and ``accepted`` / ``closed`` record
    whether ``accept()`` / ``close()`` were awaited.
    """

    def __init__(self, messages):
        self.accepted = False
        self.closed = False
        self.sent_bytes = []
        self.sent_texts = []
        # Copy so the caller's sequence is not consumed by receive().
        self._messages = [*messages]

    async def accept(self):
        self.accepted = True

    async def send_bytes(self, data):
        self.sent_bytes.append(data)

    async def send_text(self, data):
        self.sent_texts.append(data)

    async def receive(self):
        # Oldest message first; raises IndexError once the script is exhausted.
        next_message = self._messages.pop(0)
        return next_message

    async def close(self):
        self.closed = True
unpackb(data): + calls.append(("unpackb", data)) + return {"unpacked": data} + + fake_openpi_client = types.ModuleType("openpi_client") + fake_openpi_client.msgpack_numpy = FakeMsgpackNumpy + monkeypatch.setitem(sys.modules, "openpi_client", fake_openpi_client) + + assert openpi_connection._pack({"x": 1}) == b"packed" + assert openpi_connection._unpack(b"payload") == {"unpacked": b"payload"} + assert calls == [ + ("packb", {"x": 1}), + ("unpackb", b"payload"), + ] + + +def test_handle_connection_returns_structured_error_for_invalid_payload(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + monkeypatch.setattr( + openpi_connection, + "_unpack", + lambda _data: (_ for _ in ()).throw(ValueError("bad payload traceback")), + ) + + websocket = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"bad"}, + {"type": "websocket.disconnect"}, + ] + ) + serving = MagicMock() + connection = openpi_connection.RobotRealtimeConnection(websocket, serving) + + asyncio.run(connection.handle_connection()) + + assert websocket.accepted is True + assert websocket.sent_bytes[1] == {"type": "error", "message": "Invalid request payload"} + assert "traceback" not in str(websocket.sent_bytes[1]).lower() + assert websocket.sent_texts == [] + serving.infer.assert_not_called() + serving.reset.assert_not_called() + + +def test_handle_connection_rejects_oversized_payload_before_unpack(monkeypatch): + unpack_mock = MagicMock(side_effect=AssertionError("_unpack should not be called")) + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + monkeypatch.setattr(openpi_connection, "_unpack", unpack_mock) + monkeypatch.setattr(openpi_connection, "MAX_OPENPI_PAYLOAD_BYTES", 4) + + websocket = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"too-large"}, + {"type": "websocket.disconnect"}, + ] + ) + serving = MagicMock() + connection = openpi_connection.RobotRealtimeConnection(websocket, serving) + + 
asyncio.run(connection.handle_connection()) + + assert websocket.sent_bytes[1] == {"type": "error", "message": "Invalid request payload"} + unpack_mock.assert_not_called() + serving.infer.assert_not_called() + serving.reset.assert_not_called() + + +def test_handle_connection_returns_structured_error_for_infer_exception(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + monkeypatch.setattr( + openpi_connection, + "_unpack", + lambda _data: {"prompt": "pick up the object"}, + ) + + websocket = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"request"}, + {"type": "websocket.disconnect"}, + ] + ) + serving = MagicMock() + serving.infer = AsyncMock(side_effect=RuntimeError("secret traceback text")) + connection = openpi_connection.RobotRealtimeConnection(websocket, serving) + + asyncio.run(connection.handle_connection()) + + assert websocket.sent_bytes[1] == {"type": "error", "message": "Internal inference error"} + assert "secret traceback text" not in str(websocket.sent_bytes[1]) + assert websocket.sent_texts == [] + serving.infer.assert_awaited_once_with( + {"prompt": "pick up the object"}, + session_id="default", + reset=True, + ) + + +def test_handle_connection_closes_websocket_on_idle_timeout(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + + websocket = FakeWebSocket([]) + + async def never_receives(): + await asyncio.sleep(1) + + websocket.receive = never_receives + serving = MagicMock() + serving.policy_server_config = PolicyServerConfig( + { + "image_resolution": (180, 320), + "n_external_cameras": 2, + "needs_wrist_camera": True, + "needs_stereo_camera": False, + "needs_session_id": True, + "action_space": "joint_position", + } + ) + connection = openpi_connection.RobotRealtimeConnection( + websocket, + serving, + idle_timeout=0.01, + ) + + asyncio.run(connection.handle_connection()) + + assert websocket.accepted is True + assert websocket.sent_bytes[0]["action_space"] == 
"joint_position" + assert websocket.closed is True + assert websocket.sent_texts == [] + serving.infer.assert_not_called() + + +def test_handle_connection_keeps_session_state_per_websocket(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + requests = { + b"a1": {"prompt": "first", "session_id": "session-a"}, + b"a2": {"prompt": "second", "session_id": "session-a"}, + b"b1": {"prompt": "other", "session_id": "session-b"}, + } + monkeypatch.setattr(openpi_connection, "_unpack", lambda data: dict(requests[data])) + serving = _serving_mock() + + websocket_a = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"a1"}, + {"type": "websocket.receive", "bytes": b"a2"}, + {"type": "websocket.disconnect"}, + ] + ) + websocket_b = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"b1"}, + {"type": "websocket.disconnect"}, + ] + ) + + asyncio.run(openpi_connection.RobotRealtimeConnection(websocket_a, serving).handle_connection()) + asyncio.run(openpi_connection.RobotRealtimeConnection(websocket_b, serving).handle_connection()) + + calls = serving.infer.await_args_list + assert calls[0].kwargs == {"session_id": "session-a", "reset": True} + assert calls[1].kwargs == {"session_id": "session-a", "reset": False} + assert calls[2].kwargs == {"session_id": "session-b", "reset": True} + + +def test_handle_connection_reset_endpoint_resets_next_infer(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", lambda obj: obj) + requests = { + b"a1": {"prompt": "first", "session_id": "session-a"}, + b"reset": {"endpoint": "reset"}, + b"a2": {"prompt": "second", "session_id": "session-a"}, + } + monkeypatch.setattr(openpi_connection, "_unpack", lambda data: dict(requests[data])) + serving = _serving_mock() + websocket = FakeWebSocket( + [ + {"type": "websocket.receive", "bytes": b"a1"}, + {"type": "websocket.receive", "bytes": b"reset"}, + {"type": "websocket.receive", "bytes": b"a2"}, + {"type": "websocket.disconnect"}, + ] + ) + + 
def _engine_with_policy_config(policy_config=None):
    """Build a minimal engine-client stub exposing a policy server config.

    Args:
        policy_config: Value stored under ``"policy_server_config"`` in the
            stub's ``model_config``. Only ``None`` selects
            ``TEST_POLICY_SERVER_CONFIG``; an explicit empty mapping is kept
            as given so tests can exercise the empty-config case.

    Returns:
        A ``SimpleNamespace`` whose ``get_diffusion_od_config()`` returns an
        object with ``model_config == {"policy_server_config": policy_config}``.
    """
    if policy_config is None:
        policy_config = TEST_POLICY_SERVER_CONFIG
    od_config = SimpleNamespace(model_config={"policy_server_config": policy_config})
    return SimpleNamespace(get_diffusion_od_config=lambda: od_config)
"sampling_params_list": sampling_params_list, + } + ) + yield SimpleNamespace(multimodal_output={"actions": [0.0]}) + + return _generate() + + +class ConcurrentRecordingEngine(RecordingEngine): + def __init__(self, *, expected_calls: int): + super().__init__() + self.expected_calls = expected_calls + self.condition = threading.Condition() + self.saw_overlap = False + + def _wait_for_expected_calls(self): + with self.condition: + completed = self.condition.wait_for( + lambda: len(self.generate_calls) >= self.expected_calls, + timeout=5.0, + ) + self.saw_overlap = self.saw_overlap or completed + + def generate(self, *, prompt, request_id, sampling_params_list): + async def _generate(): + with self.condition: + self.generate_calls.append( + { + "prompt": prompt, + "request_id": request_id, + "sampling_params_list": sampling_params_list, + } + ) + if len(self.generate_calls) >= self.expected_calls: + self.saw_overlap = True + self.condition.notify_all() + + await asyncio.to_thread(self._wait_for_expected_calls) + yield SimpleNamespace(multimodal_output={"actions": [0.0]}) + + return _generate() + + +def test_ensure_transforms_loaded_fails_fast_on_import_error(monkeypatch): + def fail_import(module_name): + raise ModuleNotFoundError(f"missing module: {module_name}") + + monkeypatch.setattr(dreamzero_transform.importlib, "import_module", fail_import) + + with pytest.raises(RuntimeError) as exc_info: + dreamzero_transform.ensure_transforms_loaded() + + assert "Failed to import DreamZero transform module" in str(exc_info.value) + + +def test_ensure_transforms_loaded_fails_when_default_transform_missing(monkeypatch): + monkeypatch.setattr(dreamzero_transform.importlib, "import_module", lambda _module_name: None) + monkeypatch.setattr(dreamzero_transform, "TRANSFORMS", {}) + + with pytest.raises(RuntimeError) as exc_info: + dreamzero_transform.ensure_transforms_loaded() + + assert "roboarena" in str(exc_info.value) + assert "not registered" in str(exc_info.value) + + +def 
test_policy_server_config_reads_diffusion_model_config(): + policy_config = { + "image_resolution": [64, 64], + "n_external_cameras": 1, + "custom_model_key": {"nested": True}, + } + od_config = SimpleNamespace(model_config={"policy_server_config": policy_config}) + engine_client = SimpleNamespace(get_diffusion_od_config=lambda: od_config) + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert serving.policy_server_config.to_dict() == policy_config + + +def test_policy_server_config_reads_stage_config_model_config(): + policy_config = {"custom_model_key": "from-stage-config"} + engine_client = SimpleNamespace( + get_diffusion_od_config=lambda: None, + stage_configs=[ + SimpleNamespace( + stage_type="diffusion", + engine_args=SimpleNamespace(model_config={"policy_server_config": policy_config}), + ) + ], + ) + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert serving.policy_server_config.to_dict() == policy_config + + +def test_policy_server_config_reads_omegaconf_stage_config(): + engine_client = SimpleNamespace( + get_diffusion_od_config=lambda: None, + stage_configs=[ + SimpleNamespace( + stage_type="diffusion", + engine_args=SimpleNamespace( + model_config=OmegaConf.create({"policy_server_config": {"custom_model_key": "from-omegaconf"}}) + ), + ) + ], + ) + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert serving.policy_server_config.to_dict() == {"custom_model_key": "from-omegaconf"} + + +def test_policy_server_config_is_required(): + od_config = SimpleNamespace(model_config={}) + engine_client = SimpleNamespace(get_diffusion_od_config=lambda: od_config) + + with pytest.raises(ValueError) as exc_info: + openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert "policy_server_config" in str(exc_info.value) + + +def test_create_policy_server_returns_none_without_policy_config(): + od_config = 
SimpleNamespace(model_config={}) + engine_client = SimpleNamespace(get_diffusion_od_config=lambda: od_config) + + serving = openpi_serving.ServingRealtimeRobotOpenPI.create_policy_server( + engine_client=engine_client, + model_name="k2-fsa/OmniVoice", + ) + + assert serving is None + + +def test_policy_server_config_reads_engine_model_config(): + policy_config = {"custom_model_key": "custom-value"} + engine_client = SimpleNamespace(model_config=SimpleNamespace(policy_server_config=policy_config)) + + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine_client) + + assert serving.policy_server_config.to_dict() == policy_config + + +def test_build_request_uses_unique_engine_request_id_per_inference(): + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=_engine_with_policy_config()) + + request_a = serving._build_request( + {"prompt": "pick up the object"}, + session_id="session-a", + reset=True, + ) + request_b = serving._build_request( + {"prompt": "pick up the object"}, + session_id="session-a", + reset=False, + ) + + assert request_a.sampling_params.extra_args["reset"] is True + assert request_b.sampling_params.extra_args["reset"] is False + assert request_a.sampling_params.extra_args["session_id"] == "session-a" + assert request_b.sampling_params.extra_args["session_id"] == "session-a" + assert request_a.sampling_params.extra_args["robot_obs"]["prompt"] == "pick up the object" + assert request_b.sampling_params.extra_args["robot_obs"]["prompt"] == "pick up the object" + + assert request_a.request_ids == ["robot-session-a-0"] + assert request_b.request_ids == ["robot-session-a-1"] + assert request_a.request_ids[0] != request_b.request_ids[0] + + +def test_infer_keeps_session_state_but_uses_unique_engine_request_ids(): + engine = RecordingEngine() + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine) + + async def run_requests(): + await serving.infer({"prompt": "pick up the object"}, session_id="session-a", 
reset=True) + await serving.infer({"prompt": "pick up the object"}, session_id="session-a", reset=False) + + asyncio.run(run_requests()) + + assert [call["request_id"] for call in engine.generate_calls] == [ + "robot-session-a-0", + "robot-session-a-1", + ] + assert engine.generate_calls[0]["request_id"] != engine.generate_calls[1]["request_id"] + + sampling_params_a = engine.generate_calls[0]["sampling_params_list"][0] + sampling_params_b = engine.generate_calls[1]["sampling_params_list"][0] + assert sampling_params_a.extra_args["session_id"] == "session-a" + assert sampling_params_b.extra_args["session_id"] == "session-a" + assert sampling_params_a.extra_args["reset"] is True + assert sampling_params_b.extra_args["reset"] is False + + +def test_two_websocket_clients_without_session_id_do_not_conflict(monkeypatch): + monkeypatch.setattr(openpi_connection, "_pack", pickle.dumps) + monkeypatch.setattr(openpi_connection, "_unpack", pickle.loads) + + engine = ConcurrentRecordingEngine(expected_calls=2) + serving = openpi_serving.ServingRealtimeRobotOpenPI(engine_client=engine) + app = FastAPI() + + @app.websocket("/v1/realtime/robot/openpi") + async def openpi_endpoint(websocket: WebSocket): + connection = openpi_connection.RobotRealtimeConnection(websocket, serving) + await connection.handle_connection() + + def run_client(prompt: str): + with TestClient(app) as client: + with client.websocket_connect("/v1/realtime/robot/openpi") as websocket: + metadata = pickle.loads(websocket.receive_bytes()) + assert metadata["needs_session_id"] is True + + websocket.send_bytes(pickle.dumps({"prompt": prompt})) + actions = pickle.loads(websocket.receive_bytes()) + np.testing.assert_array_equal( + np.asarray(actions, dtype=np.float32), + np.asarray([0.0], dtype=np.float32), + ) + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [ + executor.submit(run_client, "first client"), + executor.submit(run_client, "second client"), + ] + for future in futures: + 
future.result(timeout=10.0) + + request_ids = [call["request_id"] for call in engine.generate_calls] + assert len(request_ids) == 2 + assert len(set(request_ids)) == 2 + assert all(request_id.startswith("robot-default-") for request_id in request_ids) + assert engine.saw_overlap is True + + sampling_params = [call["sampling_params_list"][0] for call in engine.generate_calls] + assert [params.extra_args["session_id"] for params in sampling_params] == ["default", "default"] + assert [params.extra_args["reset"] for params in sampling_params] == [True, True] diff --git a/tests/entrypoints/test_omni_entrypoints.py b/tests/entrypoints/test_omni_entrypoints.py index adcdc3e9780..7cf1f696a05 100644 --- a/tests/entrypoints/test_omni_entrypoints.py +++ b/tests/entrypoints/test_omni_entrypoints.py @@ -482,6 +482,19 @@ def test_openai_serving_models_can_consume_async_omni_compat_attrs(): def test_get_diffusion_od_config_returns_diffusion_stage_config(): + diffusion_od_config = object() + omni = object.__new__(AsyncOmni) + omni.engine = SimpleNamespace( + stage_clients=[ + SimpleNamespace(stage_type="llm"), + SimpleNamespace(stage_type="diffusion", od_config=diffusion_od_config), + ] + ) + + assert omni.get_diffusion_od_config() is diffusion_od_config + + +def test_get_diffusion_od_config_falls_back_to_inner_engine(): diffusion_od_config = object() omni = object.__new__(AsyncOmni) omni.engine = SimpleNamespace( diff --git a/tests/entrypoints/test_resolve_dreamzero_config.py b/tests/entrypoints/test_resolve_dreamzero_config.py new file mode 100644 index 00000000000..a962c5262c1 --- /dev/null +++ b/tests/entrypoints/test_resolve_dreamzero_config.py @@ -0,0 +1,56 @@ +import pytest + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.stage_diffusion_proc import StageDiffusionProc +from vllm_omni.entrypoints.utils import load_stage_configs_from_model, resolve_model_config_path + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def 
test_dreamzero_vla_resolves_to_dreamzero_config(monkeypatch): + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.get_config", + lambda _model, trust_remote_code=True: type("Cfg", (), {"model_type": "vla"})(), + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils._looks_like_dreamzero", + lambda _model: True, + ) + result = resolve_model_config_path("GEAR-Dreams/DreamZero-DROID") + + assert result is not None + assert result.endswith("vllm_omni/deploy/dreamzero.yaml") + + +def test_dreamzero_config_sets_model_class_and_policy_config(monkeypatch): + monkeypatch.setattr( + "vllm_omni.config.stage_config.StageConfigFactory._auto_detect_model_type", + classmethod(lambda _cls, _model, trust_remote_code=True: ("vla", None)), + ) + monkeypatch.setattr( + "vllm_omni.diffusion.utils.hf_utils._looks_like_dreamzero", + lambda _model: True, + ) + + stage_configs = load_stage_configs_from_model("GEAR-Dreams/DreamZero-DROID") + engine_args = stage_configs[0].engine_args + + assert engine_args.model_class_name == "DreamZeroPipeline" + assert engine_args.model_config.policy_server_config.action_space == "joint_position" + + +def test_dreamzero_enrich_config_preserves_explicit_model_class_name(monkeypatch): + monkeypatch.setattr( + "vllm.transformers_utils.config.get_hf_file_to_dict", + lambda path, _model: None if path == "model_index.json" else {"model_type": "vla", "architectures": ["VLA"]}, + ) + + od_config = OmniDiffusionConfig( + model="GEAR-Dreams/DreamZero-DROID", + model_class_name="DreamZeroPipeline", + ) + proc = StageDiffusionProc(od_config.model, od_config) + + proc._enrich_config() + + assert od_config.model_class_name == "DreamZeroPipeline" diff --git a/third_party/dreamzero b/third_party/dreamzero new file mode 160000 index 00000000000..d70a8025ac7 --- /dev/null +++ b/third_party/dreamzero @@ -0,0 +1 @@ +Subproject commit d70a8025ac77f7486f38032c718a1ca814a4d8c7 diff --git a/third_party/lerobot b/third_party/lerobot new file mode 160000 index 
00000000000..017ff73fbfe --- /dev/null +++ b/third_party/lerobot @@ -0,0 +1 @@ +Subproject commit 017ff73fbfe46bf9a673cd9b402988dcb79151f7 diff --git a/third_party/openpi b/third_party/openpi new file mode 160000 index 00000000000..54cbaee6ae0 --- /dev/null +++ b/third_party/openpi @@ -0,0 +1 @@ +Subproject commit 54cbaee6ae0c010a1ed431871cdaa8f4684ac709 diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py index 555f35e173a..c88dcfd5ddd 100644 --- a/vllm_omni/config/pipeline_registry.py +++ b/vllm_omni/config/pipeline_registry.py @@ -17,11 +17,10 @@ ``vllm_omni/.../pipeline.py``. 2. Add one line to ``_OMNI_PIPELINES`` below. -Single-stage diffusion models continue to use the -``_create_default_diffusion_stage_cfg`` fallback in -``async_omni_engine.py`` — they don't need a registry entry. The empty -``_DIFFUSION_PIPELINES`` placeholder previously here (#2915) was removed -once #2987 (which would have populated it) was deferred. +Plain single-stage diffusion models continue to use the +``_create_default_diffusion_stage_cfg`` fallback in ``async_omni_engine.py``. +The empty ``_DIFFUSION_PIPELINES`` placeholder previously here (#2915) was +removed once #2987 (which would have populated it) was deferred. 
``register_pipeline(config)`` in ``stage_config`` is still supported for out-of-tree plugins and tests that create pipelines at runtime; those override @@ -65,6 +64,10 @@ "vllm_omni.model_executor.models.bagel.pipeline", "BAGEL_SINGLE_STAGE_PIPELINE", ), + "dreamzero": ( + "vllm_omni.model_executor.models.dreamzero.pipeline", + "DREAMZERO_PIPELINE", + ), "glm_image": ( "vllm_omni.model_executor.models.glm_image.pipeline", "GLM_IMAGE_PIPELINE", diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index c459ecabe73..aa142a9ffed 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -1093,6 +1093,11 @@ def create_from_model( # --- New path: check pipeline registry by model_type first --- model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code) + if model_type == "vla": + from vllm_omni.diffusion.utils.hf_utils import _looks_like_dreamzero + + if _looks_like_dreamzero(model): + model_type = "dreamzero" if model_type and model_type in _PIPELINE_REGISTRY: return cls._create_from_registry(model_type, cli_overrides, deploy_config_path) diff --git a/vllm_omni/deploy/dreamzero.yaml b/vllm_omni/deploy/dreamzero.yaml new file mode 100644 index 00000000000..e77afad8409 --- /dev/null +++ b/vllm_omni/deploy/dreamzero.yaml @@ -0,0 +1,25 @@ +# DreamZero-DROID deploy: single diffusion stage. +# +# Topology is declared in vllm_omni/model_executor/models/dreamzero/pipeline.py. +# This default uses one GPU with TP=1 and CFG parallel disabled. 
+ +pipeline: dreamzero +async_chunk: false +distributed_executor_backend: mp +dtype: bfloat16 + +stages: + - stage_id: 0 + devices: "0" + max_num_seqs: 1 + enforce_eager: true + model_class_name: DreamZeroPipeline + model_config: + default_robot_embodiment: roboarena + policy_server_config: + image_resolution: [180, 320] + n_external_cameras: 2 + needs_wrist_camera: true + needs_stereo_camera: false + needs_session_id: true + action_space: joint_position diff --git a/vllm_omni/deploy/dreamzero_tp1_cfg2.yaml b/vllm_omni/deploy/dreamzero_tp1_cfg2.yaml new file mode 100644 index 00000000000..ed9988c92a3 --- /dev/null +++ b/vllm_omni/deploy/dreamzero_tp1_cfg2.yaml @@ -0,0 +1,25 @@ +# DreamZero-DROID deploy: TP=1, CFG parallel size=2. + +pipeline: dreamzero +async_chunk: false +distributed_executor_backend: mp +dtype: bfloat16 + +stages: + - stage_id: 0 + devices: "0,1" + max_num_seqs: 1 + enforce_eager: true + model_class_name: DreamZeroPipeline + parallel_config: + tensor_parallel_size: 1 + cfg_parallel_size: 2 + model_config: + default_robot_embodiment: roboarena + policy_server_config: + image_resolution: [180, 320] + n_external_cameras: 2 + needs_wrist_camera: true + needs_stereo_camera: false + needs_session_id: true + action_space: joint_position diff --git a/vllm_omni/deploy/dreamzero_tp2_cfg1.yaml b/vllm_omni/deploy/dreamzero_tp2_cfg1.yaml new file mode 100644 index 00000000000..16f0aa7ef21 --- /dev/null +++ b/vllm_omni/deploy/dreamzero_tp2_cfg1.yaml @@ -0,0 +1,25 @@ +# DreamZero-DROID deploy: TP=2, CFG parallel disabled. 
+ +pipeline: dreamzero +async_chunk: false +distributed_executor_backend: mp +dtype: bfloat16 + +stages: + - stage_id: 0 + devices: "0,1" + max_num_seqs: 1 + enforce_eager: true + model_class_name: DreamZeroPipeline + parallel_config: + tensor_parallel_size: 2 + cfg_parallel_size: 1 + model_config: + default_robot_embodiment: roboarena + policy_server_config: + image_resolution: [180, 320] + n_external_cameras: 2 + needs_wrist_camera: true + needs_stereo_camera: false + needs_session_id: true + action_space: joint_position diff --git a/vllm_omni/deploy/dreamzero_tp2_cfg2.yaml b/vllm_omni/deploy/dreamzero_tp2_cfg2.yaml new file mode 100644 index 00000000000..76e57a8ca49 --- /dev/null +++ b/vllm_omni/deploy/dreamzero_tp2_cfg2.yaml @@ -0,0 +1,25 @@ +# DreamZero-DROID deploy: TP=2, CFG parallel size=2. + +pipeline: dreamzero +async_chunk: false +distributed_executor_backend: mp +dtype: bfloat16 + +stages: + - stage_id: 0 + devices: "0,1,2,3" + max_num_seqs: 1 + enforce_eager: true + model_class_name: DreamZeroPipeline + parallel_config: + tensor_parallel_size: 2 + cfg_parallel_size: 2 + model_config: + default_robot_embodiment: roboarena + policy_server_config: + image_resolution: [180, 320] + n_external_cameras: 2 + needs_wrist_camera: true + needs_stereo_camera: false + needs_session_id: true + action_space: joint_position diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index a6fe1e4e9c7..3dfb6c7c617 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -901,8 +901,24 @@ def enrich_config(self) -> None: self.model_class_name = "WanS2VPipeline" self.tf_model_config = TransformerConfig() self.update_multimodal_support() + elif model_type == "vla": + from vllm_omni.diffusion.utils.hf_utils import _looks_like_dreamzero + + if _looks_like_dreamzero(self.model): + self.model_class_name = "DreamZeroPipeline" + self.set_tf_model_config(TransformerConfig()) + self.update_multimodal_support() + else: + raise elif architectures and 
len(architectures) == 1: - self.model_class_name = architectures[0] + architecture = architectures[0] + from vllm_omni.diffusion.registry import DiffusionModelRegistry + + if ( + self.model_class_name is None + or DiffusionModelRegistry._try_load_model_cls(architecture) is not None + ): + self.model_class_name = architecture else: raise @@ -972,10 +988,10 @@ class DiffusionOutput: """ # Fields may be replaced with SHM handle dicts by ipc.pack_diffusion_output_shm - output: torch.Tensor | dict | None = None - trajectory_timesteps: torch.Tensor | dict | None = None - trajectory_latents: torch.Tensor | dict | None = None - trajectory_log_probs: torch.Tensor | dict | None = None + output: torch.Tensor | tuple[Any, ...] | dict[str, Any] | None = None + trajectory_timesteps: torch.Tensor | dict[str, Any] | None = None + trajectory_latents: torch.Tensor | dict[str, Any] | None = None + trajectory_log_probs: torch.Tensor | dict[str, Any] | None = None trajectory_decoded: list[Image.Image] | None = None error: str | None = None aborted: bool = False diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index c13bd3c0c37..d2624c65144 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -251,8 +251,10 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: custom_output = output.custom_output or {} model_audio_sample_rate = None model_fps = None + action_payload = None if isinstance(outputs, dict): audio_payload = outputs.get("audio") + action_payload = outputs.get("actions") custom_output.update(outputs.get("custom_output") or {}) model_audio_sample_rate = outputs.get("audio_sample_rate") model_fps = outputs.get("fps") @@ -346,6 +348,8 @@ def _audio_mm(payload: Any) -> dict[str, Any]: mm_output["audio_sample_rate"] = model_audio_sample_rate if model_fps is not None: mm_output["fps"] = model_fps + if action_payload is not None: + mm_output["actions"] = 
def swish(x: torch.Tensor) -> torch.Tensor:
    """Return ``x * sigmoid(x)`` element-wise."""
    gate = torch.sigmoid(x)
    return x * gate


class SinusoidalPositionalEncoding(nn.Module):
    """Map timesteps of shape ``(B, T)`` to sin/cos features ``(B, T, dim)``."""

    def __init__(self, embedding_dim: int) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Return ``cat(sin(t * w), cos(t * w))`` along the last axis.

        ``w`` holds ``embedding_dim // 2`` frequencies
        ``exp(-i * ln(10000) / half_dim)``.
        """
        t = timesteps.float()
        n_freqs = self.embedding_dim // 2
        indices = torch.arange(n_freqs, dtype=torch.float, device=t.device)
        log_step = torch.log(torch.tensor(10000.0)) / n_freqs
        inv_freq = (-indices * log_step).exp()
        angles = t.unsqueeze(-1) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class CategorySpecificLinear(nn.Module):
    """Linear layer with one weight matrix and bias per category.

    Attributes:
        W: ``(num_categories, input_dim, hidden_dim)``, initialised as ``0.02 * randn``.
        b: ``(num_categories, hidden_dim)``, initialised to zeros.
    """

    def __init__(self, num_categories: int, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # W is created before b so random-number consumption order is stable.
        self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
        """Compute ``x @ W[cat_ids] + b[cat_ids]`` with a batched matmul."""
        weights = self.W[cat_ids]
        bias = self.b[cat_ids].unsqueeze(1)  # broadcast over the token axis
        return torch.bmm(x, weights) + bias


class CategorySpecificMLP(nn.Module):
    """Two category-specific linear layers with a ReLU in between."""

    def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)

    def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
        """Apply ``layer2(relu(layer1(x)))`` using the weights selected by ``cat_ids``."""
        return self.layer2(F.relu(self.layer1(x, cat_ids)), cat_ids)


class MultiEmbodimentActionEncoder(nn.Module):
    """Encode actions with per-embodiment weights and a sinusoidal timestep code.

    Flow: ``actions -> W1``; concatenate with ``pos_encoding(timesteps)``;
    ``-> W2 -> swish -> W3``.

    Args:
        action_dim: size of each action vector.
        hidden_size: width of the hidden and output features.
        num_embodiments: number of categories held by each linear layer.
    """

    def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
        # W2 consumes the action features concatenated with the timestep code.
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
        """Encode a batch of action sequences.

        Args:
            actions: ``(B, T, action_dim)``.
            timesteps: ``(B, T)``, one timestep per token.
            cat_ids: ``(B,)``, embodiment id per sample.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.
        """
        action_features = self.W1(actions, cat_ids)
        # Match the timestep code's dtype to the action features before concatenating.
        time_features = self.pos_encoding(timesteps).to(dtype=action_features.dtype)
        fused = torch.cat([action_features, time_features], dim=-1)
        return self.W3(swish(self.W2(fused, cat_ids)), cat_ids)
def sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor:
    """Return a float64 ``[len(position), dim]`` embedding laid out as ``[cos | sin]``.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, got {dim}.")
    n_freqs = dim // 2
    pos64 = position.type(torch.float64)
    exponents = -torch.arange(n_freqs, dtype=pos64.dtype, device=pos64.device).div(n_freqs)
    angles = torch.outer(pos64, torch.pow(10000, exponents))
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)


def rope_params(max_seq_len: int, dim: int) -> torch.Tensor:
    """Precompute unit-magnitude complex RoPE factors of shape ``[max_seq_len, dim // 2]``.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, got {dim}.")
    positions = torch.arange(max_seq_len)
    inv_freq = 1.0 / torch.pow(10000, torch.arange(0, dim, 2).to(torch.float64).div(dim))
    angles = torch.outer(positions, inv_freq)
    # polar(1, theta) == cos(theta) + i * sin(theta)
    return torch.polar(torch.ones_like(angles), angles)


def _to_complex_pairs(x: torch.Tensor) -> torch.Tensor:
    """View a 4-D tensor as float64 complex numbers built from adjacent last-axis pairs."""
    batch, seq_len, heads, _ = x.shape
    return torch.view_as_complex(x.to(torch.float64).reshape(batch, seq_len, heads, -1, 2))


def _rotate_and_flatten(x_complex: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Multiply by ``freqs`` (with a new leading batch axis) and return the real view."""
    return torch.view_as_real(x_complex * freqs.unsqueeze(0)).flatten(3)


def rope_apply(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply precomputed complex RoPE factors to ``x``; the result is float64."""
    return _rotate_and_flatten(_to_complex_pairs(x), freqs)


def rope_action_apply(
    x: torch.Tensor,
    freqs: torch.Tensor,
    freqs_action: torch.Tensor,
    freqs_state: torch.Tensor,
    action_register_length: int | None,
    num_action_per_block: int = 32,
    num_state_per_block: int = 1,
) -> torch.Tensor:
    """Apply RoPE, appending action and state factors after ``freqs`` when requested.

    When ``action_register_length`` is not ``None``, the leading
    ``chunk_size * num_action_per_block`` rows of ``freqs_action`` and
    ``chunk_size * num_state_per_block`` rows of ``freqs_state`` are
    concatenated after ``freqs``, where
    ``chunk_size = action_register_length // (num_action_per_block + num_state_per_block)``.

    Raises:
        ValueError: if a per-block count is ``None`` while
            ``action_register_length`` is set.
    """
    x_complex = _to_complex_pairs(x)
    if action_register_length is not None:
        if num_action_per_block is None:
            raise ValueError("num_action_per_block is required when action_register_length is set.")
        if num_state_per_block is None:
            raise ValueError("num_state_per_block is required when action_register_length is set.")
        n_chunks = action_register_length // (num_action_per_block + num_state_per_block)
        n_action = n_chunks * num_action_per_block
        n_state = n_chunks * num_state_per_block
        action_part = freqs_action[:n_action].view(n_action, 1, -1)
        state_part = freqs_state[:n_state].view(n_state, 1, -1)
        freqs = torch.cat([freqs, action_part, state_part], dim=0)
    return _rotate_and_flatten(x_complex, freqs)


def causal_rope_action_apply(
    x: torch.Tensor,
    freqs: torch.Tensor,
    freqs_action: torch.Tensor,
    freqs_state: torch.Tensor,
    action_register_length: int | None,
    num_action_per_block: int,
    num_state_per_block: int,
    action_state_index: int,
) -> torch.Tensor:
    """Apply RoPE for one step, using the action/state rows of block ``action_state_index``.

    Raises:
        ValueError: if ``action_register_length`` is set but differs from
            ``num_action_per_block + num_state_per_block``.
    """
    x_complex = _to_complex_pairs(x)
    if action_register_length is not None:
        expected_length = num_action_per_block + num_state_per_block
        if action_register_length != expected_length:
            raise ValueError(
                f"action_register_length must equal num_action_per_block + num_state_per_block "
                f"({expected_length}), got {action_register_length}."
            )
        action_start = action_state_index * num_action_per_block
        state_start = action_state_index * num_state_per_block
        action_rows = freqs_action[action_start : action_start + num_action_per_block]
        state_rows = freqs_state[state_start : state_start + num_state_per_block]
        register_freqs = torch.cat([action_rows, state_rows], dim=0).view(action_register_length, 1, -1)
        freqs = torch.cat([freqs, register_freqs], dim=0)
    return _rotate_and_flatten(x_complex, freqs)


# ── Normalization ───────────────────────────────────────────────────


class WanLayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` with defaults ``eps=1e-6`` and ``elementwise_affine=False``."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False) -> None:
        super().__init__(normalized_shape=dim, eps=eps, elementwise_affine=elementwise_affine)
class MLPProj(nn.Module):
    """CLIP feature projection for i2v.

    Pipeline: LayerNorm -> ColumnParallelLinear -> GELU -> RowParallelLinear
    -> LayerNorm.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        # Both parallel layers carry a bias and return only the output tensor.
        linear_kwargs = {"bias": True, "return_bias": False}
        self.norm1 = nn.LayerNorm(in_dim)
        self.fc1 = ColumnParallelLinear(in_dim, in_dim, **linear_kwargs)
        self.act = nn.GELU()
        self.fc2 = RowParallelLinear(in_dim, out_dim, **linear_kwargs)
        self.norm2 = nn.LayerNorm(out_dim)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Project ``image_embeds`` from ``in_dim`` to ``out_dim`` features."""
        hidden = self.fc1(self.norm1(image_embeds))
        projected = self.fc2(self.act(hidden))
        return self.norm2(projected)


# ── Cross-Attention ─────────────────────────────────────────────────
# T2V and I2V cross-attention variants


class WanT2VCrossAttention(nn.Module):
    """Text-to-video cross-attention built on the vllm-omni ``Attention`` layer.

    Q/K/V are column-parallel and the output projection is row-parallel, so
    each rank handles ``num_heads // tp_size`` heads. ``window_size`` is
    accepted but not used by this class.
    """

    def __init__(self, dim: int, num_heads: int, window_size=(-1, -1), qk_norm: bool = True, eps: float = 1e-6) -> None:
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}.")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        tp_size = get_tensor_model_parallel_world_size()
        if num_heads % tp_size != 0:
            raise ValueError(f"num_heads={num_heads} must be divisible by tp_size={tp_size}.")
        self.tp_num_heads = num_heads // tp_size
        self.tp_inner_dim = self.tp_num_heads * self.head_dim

        def _column_proj() -> ColumnParallelLinear:
            return ColumnParallelLinear(dim, dim, bias=True, gather_output=False, return_bias=False)

        def _qk_norm() -> nn.Module:
            if qk_norm:
                return DistributedRMSNorm(self.tp_inner_dim, eps=eps)
            return nn.Identity()

        # Submodules are created in the same order as before: q, k, v, o, norms, attn.
        self.q = _column_proj()
        self.k = _column_proj()
        self.v = _column_proj()
        self.o = RowParallelLinear(dim, dim, bias=True, input_is_parallel=True, return_bias=False)
        self.norm_q = _qk_norm()
        self.norm_k = _qk_norm()
        self.attn = Attention(
            self.tp_num_heads,
            self.head_dim,
            causal=False,
            softmax_scale=self.head_dim**-0.5,
            skip_sequence_parallel=True,
        )

    def _project_context(self, context: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return per-head key and value projections of ``context``."""
        head_shape = (self.tp_num_heads, self.head_dim)
        key = self.norm_k(self.k(context)).unflatten(2, head_shape)
        value = self.v(context).unflatten(2, head_shape)
        return key, value

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        context_lens: torch.Tensor | None = None,
        crossattn_cache: dict | None = None,
    ) -> torch.Tensor:
        """Attend from ``x`` to ``context``.

        Args:
            x: query-side hidden states.
            context: conditioning tokens used for keys and values.
            context_lens: ignored.
            crossattn_cache: optional dict with keys ``"is_init"``, ``"k"``
                and ``"v"``. On the first call (``is_init`` false) the
                projected key/value are stored in it; afterwards the stored
                tensors are reused and ``context`` is not projected again.

        Returns:
            The output of the row-parallel projection ``o``.
        """
        del context_lens
        query = self.norm_q(self.q(x)).unflatten(2, (self.tp_num_heads, self.head_dim))
        if crossattn_cache is None:
            key, value = self._project_context(context)
        elif crossattn_cache["is_init"]:
            key = crossattn_cache["k"]
            value = crossattn_cache["v"]
        else:
            # The flag is set before projecting, as in the original ordering.
            crossattn_cache["is_init"] = True
            key, value = self._project_context(context)
            crossattn_cache["k"] = key
            crossattn_cache["v"] = value
        attended = self.attn(query, key, value)
        return self.o(attended.flatten(2))
+ """ + + def __init__(self, dim: int, num_heads: int, window_size=(-1, -1), qk_norm: bool = True, eps: float = 1e-6) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}.") + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + tp_size = get_tensor_model_parallel_world_size() + if num_heads % tp_size != 0: + raise ValueError(f"num_heads={num_heads} must be divisible by tp_size={tp_size}.") + self.tp_num_heads = num_heads // tp_size + self.tp_inner_dim = self.tp_num_heads * self.head_dim + self.q = ColumnParallelLinear(dim, dim, bias=True, gather_output=False, return_bias=False) + self.k = ColumnParallelLinear(dim, dim, bias=True, gather_output=False, return_bias=False) + self.v = ColumnParallelLinear(dim, dim, bias=True, gather_output=False, return_bias=False) + self.o = RowParallelLinear(dim, dim, bias=True, input_is_parallel=True, return_bias=False) + self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) if qk_norm else nn.Identity() + self.k_img = ColumnParallelLinear(dim, dim, bias=True, gather_output=False, return_bias=False) + self.v_img = ColumnParallelLinear(dim, dim, bias=True, gather_output=False, return_bias=False) + self.norm_k_img = DistributedRMSNorm(self.tp_inner_dim, eps=eps) if qk_norm else nn.Identity() + self.attn = Attention( + self.tp_num_heads, + self.head_dim, + causal=False, + softmax_scale=self.head_dim**-0.5, + skip_sequence_parallel=True, + ) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor, + context_lens: torch.Tensor | None = None, + crossattn_cache: dict | None = None, + ) -> torch.Tensor: + del context_lens + context_img = context[:, :257] + context = context[:, 257:] + n, d = self.tp_num_heads, self.head_dim + q = self.norm_q(self.q(x)).unflatten(2, (n, d)) + if crossattn_cache is not None: + if not 
class CausalWanSelfAttention(nn.Module):
    """Self-attention over a rolling KV cache, with optional action/state tokens.

    Video-token keys and values are appended to the cache, which is truncated
    to the most recent ``max_attention_size`` positions. When action/state
    register tokens are present (the trailing ``action_register_length``
    positions of the input), they take part in attention but are not written
    to the cache. ``sink_size`` is accepted but not used by this class.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        frame_seqlen: int,
        local_attn_size: int = -1,
        sink_size: int = 0,
        num_frame_per_block: int = 1,
        qk_norm: bool = True,
        eps: float = 1e-6,
        num_action_per_block: int = 32,
        num_state_per_block: int = 1,
    ) -> None:
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}.")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        tp_size = get_tensor_model_parallel_world_size()
        if num_heads % tp_size != 0:
            raise ValueError(f"num_heads={num_heads} must be divisible by tp_size={tp_size}.")
        self.tp_num_heads = num_heads // tp_size
        self.tp_inner_dim = self.tp_num_heads * self.head_dim
        self.local_attn_size = local_attn_size
        self.num_frame_per_block = num_frame_per_block
        self.frame_seqlen = frame_seqlen
        self.num_action_per_block = num_action_per_block
        self.num_state_per_block = num_state_per_block
        # A local_attn_size of -1 falls back to a 21-frame window.
        window_frames = 21 if local_attn_size == -1 else local_attn_size
        self.max_attention_size = window_frames * frame_seqlen

        def _column_proj() -> ColumnParallelLinear:
            return ColumnParallelLinear(dim, dim, bias=True, gather_output=False, return_bias=False)

        def _qk_norm() -> nn.Module:
            if qk_norm:
                return DistributedRMSNorm(self.tp_inner_dim, eps=eps)
            return nn.Identity()

        # Submodules are created in the same order as before: q, k, v, o, norms, attn.
        self.q = _column_proj()
        self.k = _column_proj()
        self.v = _column_proj()
        self.o = RowParallelLinear(dim, dim, bias=True, input_is_parallel=True, return_bias=False)
        self.norm_q = _qk_norm()
        self.norm_k = _qk_norm()
        self.attn = Attention(
            self.tp_num_heads,
            self.head_dim,
            causal=False,
            softmax_scale=self.head_dim**-0.5,
            skip_sequence_parallel=True,
        )

    def forward(
        self,
        x: torch.Tensor,
        freqs: torch.Tensor,
        freqs_action: torch.Tensor,
        freqs_state: torch.Tensor,
        action_register_length: int | None,
        kv_cache: torch.Tensor | None = None,
        current_start_frame: int = 0,
        is_tf: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run one inference step against the KV cache.

        Args:
            x: input hidden states; when ``action_register_length`` is set,
                the trailing that-many positions are the register tokens.
            freqs, freqs_action, freqs_state: RoPE tables passed through to
                ``causal_rope_action_apply``.
            action_register_length: number of trailing register tokens, or
                ``None`` when there are none.
            kv_cache: stacked cache indexed as ``kv_cache[0]`` (keys) and
                ``kv_cache[1]`` (values); required.
            current_start_frame: used to pick the action/state RoPE block as
                ``max(0, (current_start_frame - 1) // num_frame_per_block)``.
            is_tf: accepted but not read by this method.

        Returns:
            ``(output, updated_kv_cache)`` where ``updated_kv_cache`` stacks
            the truncated keys and values along a new leading axis.

        Raises:
            RuntimeError: if ``kv_cache`` is ``None``.
        """
        head_shape = (self.tp_num_heads, self.head_dim)

        query = self.norm_q(self.q(x)).unflatten(2, head_shape)
        key = self.norm_k(self.k(x)).unflatten(2, head_shape)
        value = self.v(x).unflatten(2, head_shape)

        if kv_cache is None:
            raise RuntimeError("Inference only: kv_cache is required.")

        block_index = max(0, (current_start_frame - 1) // self.num_frame_per_block)
        rope_args = (
            freqs,
            freqs_action,
            freqs_state,
            action_register_length,
            self.num_action_per_block,
            self.num_state_per_block,
            block_index,
        )
        query_rot = causal_rope_action_apply(query, *rope_args).type_as(value)
        key_rot = causal_rope_action_apply(key, *rope_args).type_as(value)

        has_registers = action_register_length is not None
        if has_registers:
            # Peel the trailing register tokens off so only video tokens reach the cache.
            split = -action_register_length
            register_query, query_rot = query_rot[:, split:], query_rot[:, :split]
            register_key, key_rot = key_rot[:, split:], key_rot[:, :split]
            register_value, value = value[:, split:], value[:, :split]

        window = self.max_attention_size
        cached_key = kv_cache[0]
        cached_value = kv_cache[1]
        new_key = torch.cat([cached_key, key_rot], dim=1)[:, -window:]
        new_value = torch.cat([cached_value, value], dim=1)[:, -window:]

        if has_registers:
            attn_query = torch.cat([query_rot, register_query], dim=1)
            attn_key = torch.cat([new_key, register_key], dim=1)
            attn_value = torch.cat([new_value, register_value], dim=1)
        else:
            attn_query, attn_key, attn_value = query_rot, new_key, new_value

        attended = self.attn(attn_query, attn_key, attn_value)
        updated_kv_cache = torch.stack([new_key, new_value], dim=0)

        return self.o(attended.flatten(2)), updated_kv_cache
class CausalHead(nn.Module):
    """Output head: LayerNorm, shift/scale modulation, then a plain ``nn.Linear``.

    The linear layer maps ``dim`` to ``prod(patch_size) * out_dim`` channels.
    """

    def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6) -> None:
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.norm = WanLayerNorm(dim, eps)
        # ``head`` is created before ``modulation`` to keep RNG consumption order.
        self.head = nn.Linear(dim, math.prod(patch_size) * out_dim)
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Project modulated, normalised hidden states.

        Args:
            x: ``[B, L1, C]`` hidden states.
            e: ``[B, F, 1, C]`` time embedding (already unsqueezed).

        Returns:
            ``head(norm(x) * (1 + scale) + shift)``, where shift and scale are
            the two chunks of ``modulation + e`` along axis 2.
        """
        shift, scale = (self.modulation.unsqueeze(1) + e).chunk(2, dim=2)
        modulated = self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)
        return self.head(modulated)
self.freq_dim = freq_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.local_attn_size = max_chunk_size * num_frame_per_block + 1 if max_chunk_size != -1 else -1 + self.num_frame_per_block = num_frame_per_block + self.action_dim = action_dim + self.num_action_per_block = num_action_per_block + self.num_state_per_block = num_state_per_block + + max_num_embodiments_local = 1 + self.state_encoder = CategorySpecificMLP( + num_categories=max_num_embodiments_local, + input_dim=max_state_dim, + hidden_dim=hidden_size, + output_dim=dim, + ) + self.action_encoder = MultiEmbodimentActionEncoder( + action_dim=action_dim, + hidden_size=dim, + num_embodiments=max_num_embodiments_local, + ) + self.action_decoder = CategorySpecificMLP( + num_categories=max_num_embodiments_local, + input_dim=dim, + hidden_dim=hidden_size, + output_dim=action_dim, + ) + + # Disable the Conv3d GEMM rewrite for patch embedding. + self.patch_embedding = Conv3dLayer( + in_dim, + dim, + kernel_size=patch_size, + stride=patch_size, + ) + self.patch_embedding.enable_linear = False + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate="tanh"), + nn.Linear(dim, dim), + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim), + ) + self.time_projection = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, dim * 6), + ) + + cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" + self.blocks = nn.ModuleList( + [ + CausalWanAttentionBlock( + cross_attn_type, + dim, + ffn_dim, + num_heads, + frame_seqlen, + self.local_attn_size, + sink_size, + num_frame_per_block, + qk_norm, + cross_attn_norm, + eps, + num_action_per_block, + num_state_per_block, + ) + for _ in range(num_layers) + ] + ) + + self.head = CausalHead(dim, out_dim, patch_size, eps) + + if dim % num_heads != 0: + raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}.") + if (dim // 
def init_weights(self) -> None:
    """Initialise the model's parameters in place.

    Three passes, in this order:
    1. every linear-like sub-module gets Xavier-uniform weights and zero bias
       (the tensor-parallel linears use their ``input_size``/``output_size``
       for the bound);
    2. the patch embedding gets Xavier-uniform weights over its flattened
       kernel and a uniform bias bounded by ``1 / sqrt(fan_in)``;
    3. the text and time embedding MLP weights are re-drawn from
       ``N(0, 0.02)`` and the output head weight is zeroed.
    """
    for sub in self.modules():
        if isinstance(sub, nn.Linear):
            nn.init.xavier_uniform_(sub.weight)
        elif isinstance(sub, (ColumnParallelLinear, RowParallelLinear)):
            # Xavier-uniform bound written out by hand from the layer's declared sizes.
            limit = math.sqrt(6.0 / float(sub.input_size + sub.output_size))
            nn.init.uniform_(sub.weight, -limit, limit)
        else:
            continue
        if sub.bias is not None:
            nn.init.zeros_(sub.bias)

    patch = self.patch_embedding
    # flatten(1) is a view, so the in-place init writes through to the conv weight.
    nn.init.xavier_uniform_(patch.weight.flatten(1))
    if patch.bias is not None:
        kernel_fan_in = patch.in_channels * math.prod(patch.kernel_size)
        bias_limit = 1 / math.sqrt(kernel_fan_in)
        nn.init.uniform_(patch.bias, -bias_limit, bias_limit)

    # Text embedding first, then time embedding, overriding the Xavier weights from pass 1.
    for embedding in (self.text_embedding, self.time_embedding):
        for layer in embedding.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, std=0.02)

    nn.init.zeros_(self.head.head.weight)
self.freqs_action.to(device) + if self.freqs_state.device != device: + self.freqs_state = self.freqs_state.to(device) + + f, h, w = grid_size.tolist() + freqs = torch.cat( + [ + self.freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(f * h * w, 1, -1) + return freqs + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor) -> torch.Tensor: + """Reconstruct video from patch embeddings.""" + B = x.shape[0] + c = self.out_dim + grid_size = grid_size.tolist() + expected_seq_len = math.prod(grid_size) + if x.shape[1] != expected_seq_len: + raise ValueError(f"x sequence length must equal product(grid_size)={expected_seq_len}, got {x.shape[1]}.") + x = x.view(B, *grid_size, *self.patch_size, c) + x = torch.einsum("bfhwpqrc->bcfphqwr", x) + x = x.reshape(B, c, *[i * j for i, j in zip(grid_size, self.patch_size)]) + return x + + def _forward_blocks( + self, + x: torch.Tensor, + seq_len: int, + freqs: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: torch.Tensor | None, + embodiment_id: torch.Tensor | None, + action: torch.Tensor | None, + timestep_action: torch.Tensor | None, + state: torch.Tensor | None, + kv_cache: list[torch.Tensor], + current_start_frame: int, + ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor]]: + x = x.flatten(start_dim=2).transpose(1, 2) + B = x.shape[0] + F_t = timestep.shape[1] + + if action is not None: + # Current DreamZero checkpoints have one local action/state adapter. + # Global embodiment IDs are used by transforms and normalization. 
+ adapter_category_id = torch.zeros(B, dtype=torch.long, device=x.device) + action_features = self.action_encoder(action, timestep_action, adapter_category_id) + state_features = self.state_encoder(state, adapter_category_id) + action_register = torch.cat([action_features, state_features], dim=1) + action_length = action_features.shape[1] + action_register_length = action_register.shape[1] + x = torch.cat([x, action_register], dim=1) + else: + action_length = 0 + action_register_length = None + + timestep = timestep.unsqueeze(-1).expand(B, F_t, seq_len // F_t).reshape(B, -1) + if action is not None: + if timestep_action is None or state is None: + raise RuntimeError("timestep_action and state are required when action is provided.") + state_features_t = self.state_encoder(state, adapter_category_id) + stride = timestep_action.shape[1] // state_features_t.shape[1] + timestep_state = timestep_action[:, ::stride] + timestep = torch.cat([timestep, timestep_action, timestep_state], dim=1) + + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).type_as(x)) + e = e.unflatten(dim=0, sizes=(B, -1)) + e0 = self.time_projection(e) + e0 = e0.unflatten(dim=2, sizes=(6, self.dim)) + + context = self.text_embedding(context) + if clip_feature is not None: + clip_embedding = self.img_emb(clip_feature) + context = torch.cat([clip_embedding, context], dim=1) + + updated_kv_caches: list[torch.Tensor] = [] + for block_index, block in enumerate(self.blocks): + x, updated_kv_cache = block( + x=x, + e=e0, + freqs=freqs, + freqs_action=self.freqs_action, + freqs_state=self.freqs_state, + context=context, + action_register_length=action_register_length, + kv_cache=kv_cache[block_index] if kv_cache else None, + current_start_frame=current_start_frame, + ) + updated_kv_caches.append(updated_kv_cache) + + if action is not None: + action_noise_pred = x[:, seq_len : seq_len + action_length] + action_noise_pred = self.action_decoder(action_noise_pred, 
adapter_category_id) + else: + action_noise_pred = None + + x_video = x[:, :seq_len] + e_video = e[:, :seq_len] + x_video = self.head(x_video, e_video.unsqueeze(2)) + + return x_video, action_noise_pred, updated_kv_caches + + def _forward_inference( + self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + seq_len: int, + kv_cache: list[torch.Tensor], + crossattn_cache: list[torch.Tensor], + current_start_frame: int, + y: torch.Tensor | None = None, + clip_feature: torch.Tensor | None = None, + action: torch.Tensor | None = None, + timestep_action: torch.Tensor | None = None, + state: torch.Tensor | None = None, + embodiment_id: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor]]: + if self.model_type == "i2v": + if clip_feature is None or y is None: + raise RuntimeError("clip_feature and y are required for i2v inference.") + if context.shape[1] != self.text_len: + raise ValueError(f"context length must be {self.text_len}, got {context.shape[1]}.") + + if y is not None: + x = torch.cat([x, y.to(dtype=x.dtype)], dim=1) + + x = self.patch_embedding(x) + grid_size = torch.tensor(x.shape[2:], dtype=torch.long) + freqs = self._create_freqs(grid_size, current_start_frame) + + x_video, action_noise_pred, updated_kv_caches = self._forward_blocks( + x=x, + seq_len=seq_len, + freqs=freqs, + timestep=timestep, + context=context, + clip_feature=clip_feature, + embodiment_id=embodiment_id, + action=action, + timestep_action=timestep_action, + state=state, + kv_cache=kv_cache, + current_start_frame=current_start_frame, + ) + + x_video = x_video.clone() + if action_noise_pred is not None: + action_noise_pred = action_noise_pred.clone() + + video_noise_pred = self.unpatchify(x_video, grid_size) + return video_noise_pred, action_noise_pred, updated_kv_caches + + def forward(self, *args: Any, **kwargs: Any): + """Inference only. 
class DreamZeroLayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the input tensor's type."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` with the parent implementation, then match ``x``'s type."""
        normalized = nn.LayerNorm.forward(self, x)
        return normalized.type_as(x)
class DreamZeroVisionAttentionBlock(nn.Module):
    """One vision-tower layer: self-attention and a GELU MLP, each with a residual.

    ``post_norm`` selects where the layer norms sit: after each sub-layer
    when true, before each sub-layer otherwise.
    """

    def __init__(
        self,
        dim: int,
        mlp_ratio: float,
        num_heads: int,
        post_norm: bool = False,
        activation: str = "gelu",
        proj_dropout: float = 0.0,
        norm_eps: float = 1e-5,
    ) -> None:
        super().__init__()
        # Only GELU is implemented; reject anything else up front.
        if activation != "gelu":
            raise ValueError(f"DreamZero image encoder uses GELU; got activation={activation!r}.")
        self.post_norm = post_norm

        self.norm1 = DreamZeroLayerNorm(dim, eps=norm_eps)
        self.attn = DreamZeroVisionSelfAttention(dim, num_heads, proj_dropout=proj_dropout)
        self.norm2 = DreamZeroLayerNorm(dim, eps=norm_eps)

        mlp_width = int(dim * mlp_ratio)
        mlp_layers = [
            nn.Linear(dim, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, dim),
            nn.Dropout(proj_dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        if not self.post_norm:
            # Pre-norm: normalise the input of each sub-layer.
            x = x + self.attn(self.norm1(x))
            return x + self.mlp(self.norm2(x))

        # Post-norm: normalise the output of each sub-layer.
        x = x + self.norm1(self.attn(x))
        return x + self.norm2(self.mlp(x))
class DreamZeroImageEncoder(nn.Module):
    """Wraps the vision tower together with its input normalisation."""

    def __init__(self) -> None:
        super().__init__()
        self.model = _DreamZeroCLIPContainer()
        # Kept as a Compose so the normalisation can be fetched as its last stage.
        clip_normalize = T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )
        self.transforms = T.Compose([clip_normalize])

    def encode_image(self, videos: torch.Tensor) -> torch.Tensor:
        """Resize, normalise and encode images for I2V conditioning.

        ``videos`` is iterated along its first dimension and every slice is
        bicubically resized to the tower's square input size, then all slices
        are concatenated along dim 0. Pixel values are rescaled with
        ``x * 0.5 + 0.5`` before normalisation (presumably mapping [-1, 1]
        inputs to [0, 1] — confirm against the caller). The result is the
        tower output with its final transformer block skipped, returned as a
        fresh tensor.
        """
        side = self.model.visual.image_size
        target_size = (side, side)

        resized = []
        for frame_batch in videos:
            resized.append(
                F.interpolate(
                    frame_batch,
                    size=target_size,
                    mode="bicubic",
                    align_corners=False,
                )
            )
        pixels = torch.cat(resized)

        # In-place rescale is safe: `pixels` was freshly allocated by torch.cat.
        normalize = self.transforms.transforms[-1]
        pixels = normalize(pixels.mul_(0.5).add_(0.5))

        # Run the tower in whatever dtype its weights currently have.
        tower = self.model.visual
        pixels = pixels.to(dtype=next(tower.parameters()).dtype)
        features = tower(pixels, use_31_block=True)
        return features.clone()
+ +Entry point for DiffusionEngine.step() → pipeline.forward(req) +""" + +from __future__ import annotations + +import copy +import json +import logging +import os +import re as re_module +from collections import OrderedDict +from collections.abc import Iterable + +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer, UMT5Config, UMT5EncoderModel +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import ( + DistributedAutoencoderKLWan, +) +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dreamzero.causal_wan_model import CausalWanModel +from vllm_omni.diffusion.models.dreamzero.image_encoder import DreamZeroImageEncoder +from vllm_omni.diffusion.models.dreamzero.state_dreamzero import DreamZeroState +from vllm_omni.diffusion.models.dreamzero.transform import ( + DEFAULT_EMBODIMENT, + ensure_transforms_loaded, +) +from vllm_omni.diffusion.models.dreamzero.transform.base import get_transform +from vllm_omni.diffusion.models.dreamzero.utils import ( + DEFAULT_CFG_SCALE, + DEFAULT_EMBODIMENT_NAME_TO_ID, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_NUM_INFERENCE_STEPS, + DEFAULT_SEED, + DEFAULT_SIGMA_SHIFT, +) +from vllm_omni.diffusion.models.schedulers.scheduling_flow_unipc_multistep import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = logging.getLogger(__name__) +MAX_DREAMZERO_SESSIONS = 64 + + +# 
class VideoActionScheduler:
    """Presents a video scheduler and an action scheduler as one ``step()`` call."""

    def __init__(self, video_scheduler, action_scheduler):
        self.video_scheduler = video_scheduler
        self.action_scheduler = action_scheduler

    def step(self, noise_pred, t, latents, return_dict=False, generator=None):
        """Step both schedulers and return ``((video_out, action_out),)``.

        ``noise_pred``, ``t`` and ``latents`` are indexable pairs: element 0
        goes to the video scheduler, element 1 to the action scheduler. The
        wrapped schedulers are always called with ``return_dict=False`` and
        only their first output is kept; this method's own ``return_dict``
        argument is accepted but not used. The same ``generator`` is handed
        to both.
        """
        schedulers = (self.video_scheduler, self.action_scheduler)
        stepped = tuple(
            scheduler.step(
                noise_pred[index],
                t[index],
                latents[index],
                return_dict=False,
                generator=generator,
            )[0]
            for index, scheduler in enumerate(schedulers)
        )
        return (stepped,)
+ Exceptions: + - tokenizer loads from `google/umt5-xxl` + - VAE uses `DistributedAutoencoderKLWan` as the local execution module. + It can be bootstrapped either from an explicit diffusers source + (`od_config.model_paths["vae"]`) or directly from constructor defaults + that match Wan2.1 VAE, after which DreamZero root + `action_head.vae.*` weights are remapped onto that module in + `load_weights()` + """ + super().__init__() + + model_path = od_config.model + model_config = od_config.model_config + local_files_only = os.path.exists(model_path) + self.od_config = od_config + ensure_transforms_loaded() + self.default_robot_embodiment = model_config.get( + "default_robot_embodiment", + DEFAULT_EMBODIMENT, + ) + + root_cfg = self._load_repo_json(model_path, "config.json", local_files_only) + if root_cfg is None: + raise ValueError(f"DreamZero requires root config.json in {model_path}.") + action_head_cfg = root_cfg["action_head_cfg"] + ah_config = action_head_cfg["config"] + diffusion_model_cfg = ah_config["diffusion_model_cfg"] + + # ---- Tokenizer ---- + tokenizer_source = od_config.model_paths.get("tokenizer", "google/umt5-xxl") + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source) + + # Instantiate from config; weights load through `load_weights()`. + umt5_config = UMT5Config( + d_model=4096, + d_ff=10240, + num_heads=64, + num_layers=24, + vocab_size=256384, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dense_act_fn="gelu_new", + feed_forward_proj="gated-gelu", + is_encoder_decoder=False, + ) + self.text_encoder = UMT5EncoderModel(umt5_config) + + self.image_encoder = DreamZeroImageEncoder() + + # Build a compatible VAE module, then fill it through `load_weights()`. 
+ vae_source = od_config.model_paths.get("vae") + if vae_source: + self.vae = DistributedAutoencoderKLWan.from_pretrained( + vae_source, + torch_dtype=torch.float32, + ) + elif local_files_only and os.path.isdir(os.path.join(model_path, "vae")): + self.vae = DistributedAutoencoderKLWan.from_pretrained( + model_path, + subfolder="vae", + torch_dtype=torch.float32, + ) + else: + self.vae = DistributedAutoencoderKLWan() + self.vae.init_distributed() + if not ( + getattr(od_config, "enable_cpu_offload", False) or getattr(od_config, "enable_layerwise_offload", False) + ): + self.vae = self.vae.to(device=get_local_device(), dtype=od_config.dtype) + self.register_buffer( + "vae_latents_mean", + torch.tensor(self.vae.config.latents_mean, dtype=torch.float32).view(1, -1, 1, 1, 1), + persistent=False, + ) + self.register_buffer( + "vae_latents_inv_std", + (1.0 / torch.tensor(self.vae.config.latents_std, dtype=torch.float32)).view(1, -1, 1, 1, 1), + persistent=False, + ) + + # Filter out keys not accepted by `CausalWanModel.__init__`. + transformer_kwargs = {k: v for k, v in diffusion_model_cfg.items() if k not in ("_convert_", "_target_")} + transformer_kwargs["action_dim"] = ah_config["action_dim"] + transformer_kwargs["max_state_dim"] = ah_config["max_state_dim"] + transformer_kwargs["num_frame_per_block"] = ah_config["num_frame_per_block"] + self.transformer = CausalWanModel(**transformer_kwargs) + + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=1, + use_dynamic_shifting=False, + ) + + self._states: OrderedDict[str, DreamZeroState] = OrderedDict() + self._max_session_states = MAX_DREAMZERO_SESSIONS + self.state = self._get_or_create_state("default") + + # Keep runtime inference settings separate from the training-time config. 
+ self.num_inference_steps: int = model_config.get( + "num_inference_steps", + DEFAULT_NUM_INFERENCE_STEPS, + ) + self.cfg_scale: float = model_config.get("cfg_scale", DEFAULT_CFG_SCALE) + self.sigma_shift: float = model_config.get("sigma_shift", DEFAULT_SIGMA_SHIFT) + self.num_frames: int = ah_config["num_frames"] + self.num_frame_per_block: int = ah_config["num_frame_per_block"] + self.action_horizon: int = ah_config["action_horizon"] + + self.decouple_inference_noise: bool = ah_config["decouple_inference_noise"] + self.video_inference_final_noise: float = ah_config["video_inference_final_noise"] + + self.seed: int = model_config.get("seed", DEFAULT_SEED) + + # Model-level constants for state/action padding. + self.max_state_dim: int = ah_config["max_state_dim"] + self.max_action_dim: int = ah_config["max_action_dim"] + + self.negative_prompt: str = model_config.get("negative_prompt", DEFAULT_NEGATIVE_PROMPT) + + # Embodiment name → numeric ID mapping (model knowledge) + self.embodiment_name_to_id: dict[str, int] = model_config.get( + "embodiment_name_to_id", + DEFAULT_EMBODIMENT_NAME_TO_ID, + ) + + # Prefer root `experiment_cfg/metadata.json`, then `model_config`. 
+ stats_path = model_config.get("action_norm_stats_path") + metadata = self._load_repo_json(model_path, "experiment_cfg/metadata.json", local_files_only) + if metadata is not None: + self.action_norm_stats = self._parse_action_norm_stats(metadata) + self.state_norm_stats = self._parse_state_norm_stats(metadata) + elif stats_path: + self.action_norm_stats = self._load_action_norm_stats(stats_path) + self.state_norm_stats = {} + else: + self.action_norm_stats: dict[str, dict[str, torch.Tensor]] = {} + self.state_norm_stats: dict[str, dict[str, torch.Tensor]] = {} + + # Whether model uses relative actions (need to add back last state) + self.relative_action: bool = model_config.get("relative_action", True) + # Number of action dims that are relative (DROID: 7 = joint only, gripper is absolute) + self.relative_action_dim: int = model_config.get("relative_action_dim", 7) + + self._weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=model_path, + subfolder=None, + revision=None, + prefix="", + fall_back_to_pt=False, + allow_patterns_overrides=[ + "model-*.safetensors", + "model.safetensors", + ], + ), + ] + + def _get_or_create_state(self, session_id: str | None) -> DreamZeroState: + session_key = str(session_id or "default") + state = self._states.get(session_key) + if state is None: + state = DreamZeroState() + self._states[session_key] = state + max_states = getattr(self, "_max_session_states", MAX_DREAMZERO_SESSIONS) + while len(self._states) > max_states: + self._states.popitem(last=False) + else: + self._states.move_to_end(session_key) + return state + + # ----------------------------------------------------------------------- + # Root config loading + # ----------------------------------------------------------------------- + + @staticmethod + def _load_repo_json(model_path: str, relative_path: str, local_files_only: bool) -> dict | None: + """Load a JSON file from a local checkpoint directory or HF repo.""" + if local_files_only and 
os.path.isdir(model_path): + json_path = os.path.join(model_path, relative_path) + if not os.path.exists(json_path): + return None + with open(json_path) as f: + return json.load(f) + + try: + json_path = hf_hub_download(model_path, relative_path) + with open(json_path) as f: + return json.load(f) + except Exception: + logger.warning("Failed to load %s from %s", relative_path, model_path) + return None + + # ----------------------------------------------------------------------- + # CFGParallelMixin overrides + # ----------------------------------------------------------------------- + + def predict_noise(self, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + """Call CausalWanModel, return (video_pred, action_pred).""" + video_pred, action_pred, updated_kv_caches = self.transformer( + x=kwargs["hidden_states"], + timestep=kwargs["timestep_video"], + context=kwargs["encoder_hidden_states"], + seq_len=kwargs["seq_len"], + kv_cache=kwargs["kv_cache"], + crossattn_cache=kwargs["crossattn_cache"], + current_start_frame=kwargs["current_start_frame"], + y=kwargs.get("y"), + clip_feature=kwargs.get("clip_feature"), + action=kwargs.get("action"), + timestep_action=kwargs.get("timestep_action"), + state=kwargs.get("state_features"), + embodiment_id=kwargs.get("embodiment_id"), + ) + if kwargs.get("update_kv_cache", False) and updated_kv_caches: + state = kwargs.get("dreamzero_state", self.state) + is_neg = kwargs.get("is_negative", False) + for i, kv in enumerate(updated_kv_caches): + state.update_kv_cache(i, kv, is_negative=is_neg) + + video_pred = video_pred.clone() + if action_pred is not None: + action_pred = action_pred.clone() + else: + batch_size = kwargs["hidden_states"].shape[0] + action_pred = torch.empty( + batch_size, + 0, + self.transformer.action_dim, + device=video_pred.device, + dtype=video_pred.dtype, + ) # CFG-parallel-safe dummy action pred + return (video_pred, action_pred) + + def combine_cfg_noise( + self, + positive_noise_pred: torch.Tensor | 
tuple[torch.Tensor, ...], + negative_noise_pred: torch.Tensor | tuple[torch.Tensor, ...], + true_cfg_scale: float, + cfg_normalize: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Video: standard CFG. Action: positive only (no CFG). + action = cond only (no uncond blending) + """ + (video_pos, action_pos) = positive_noise_pred + (video_neg, _) = negative_noise_pred + video_combined = super().combine_cfg_noise(video_pos, video_neg, true_cfg_scale, cfg_normalize) + return (video_combined, action_pos) + + # ----------------------------------------------------------------------- + # ----------------------------------------------------------------------- + + def _synchronize_cfg_parallel_step_output( + self, + latents: tuple[torch.Tensor, torch.Tensor], + do_true_cfg: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Post-step sync: .contiguous() + cuda.synchronize()""" + latents = tuple(t.contiguous() for t in latents) + if do_true_cfg and get_classifier_free_guidance_world_size() > 1: + device = next((t.device for t in latents if t.is_cuda), None) + if device is not None: + torch.cuda.current_stream(device).synchronize() + return latents + + # ----------------------------------------------------------------------- + # Video preprocessing + # ----------------------------------------------------------------------- + + def _preprocess_video(self, videos: torch.Tensor) -> torch.Tensor: + """uint8 [B,T,H,W,C] → bfloat16 [B,C,T,H,W] normalized to [-1,1].""" + videos = videos.permute(0, 4, 1, 2, 3) + if videos.dtype == torch.uint8: + videos = videos.float() / 255.0 + # Cast to bf16 before normalization to preserve input rounding. 
+ videos = videos.to(dtype=torch.bfloat16) + b, c, t, h, w = videos.shape + videos = videos.permute(0, 2, 1, 3, 4) + videos = videos.reshape(b * t, c, h, w) + videos = videos * 2.0 - 1.0 + videos = videos.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4) + return videos.to(dtype=torch.bfloat16) + + # ----------------------------------------------------------------------- + # Text encoding + # ----------------------------------------------------------------------- + + def _encode_text(self, text_tokens: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """Encode text prompt via UMT5.""" + seq_lens = attention_mask.gt(0).sum(dim=1).long() + prompt_emb = self.text_encoder( + text_tokens, + attention_mask, + ).last_hidden_state + prompt_emb = prompt_emb.clone().to(dtype=torch.bfloat16) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + # ----------------------------------------------------------------------- + # Image encoding + # ----------------------------------------------------------------------- + + def _encode_image( + self, + image: torch.Tensor, + num_frames: int, + height: int, + width: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode first frame via CLIP + VAE. 
+ Returns: (clip_feas, ys, image_latent) + """ + device = image.device + batch_size = image.shape[0] + + with torch.amp.autocast(dtype=torch.bfloat16, device_type=device.type): + clip_context = self.image_encoder.encode_image(image) + + msk = torch.ones(batch_size, num_frames, height // 8, width // 8, device=device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(batch_size, msk.shape[1] // 4, 4, height // 8, width // 8) + msk = msk.transpose(1, 2) + + latent_dtype = image.dtype + image_input = image.transpose(1, 2) + image_zeros = torch.zeros( + batch_size, + 3, + num_frames - 1, + height, + width, + dtype=latent_dtype, + device=device, + ) + vae_input = torch.concat([image_input, image_zeros], dim=2) + y = self._encode_vae_latents(vae_input) + y = y.to(dtype=latent_dtype) + + new_image = y[:, :, 0:1] + y = torch.concat([msk, y], dim=1) + + return clip_context, y, new_image + + def _encode_vae_latents(self, videos: torch.Tensor) -> torch.Tensor: + """Encode videos into normalized VAE latents.""" + input_dtype = videos.dtype + hidden = self.vae._encode(videos.to(dtype=self.vae.dtype)) + mu, _ = hidden.chunk(2, dim=1) + mean = self.vae_latents_mean.to(device=mu.device, dtype=mu.dtype) + inv_std = self.vae_latents_inv_std.to(device=mu.device, dtype=mu.dtype) + mu = (mu - mean) * inv_std + return mu.to(dtype=input_dtype) + + def decode_video_latents(self, video_latents: torch.Tensor) -> torch.Tensor: + """Decode normalized VAE latents into RGB video tensors.""" + vae_dtype = self.vae.dtype + vae_device = next(self.vae.parameters()).device + latents = video_latents.to(device=vae_device, dtype=vae_dtype) + mean = self.vae_latents_mean.to(device=vae_device, dtype=vae_dtype) + inv_std = self.vae_latents_inv_std.to(device=vae_device, dtype=vae_dtype) + latents = latents / inv_std + mean + with torch.no_grad(): + return self.vae.decode(latents, return_dict=False)[0] + + # 
----------------------------------------------------------------------- + # KV cache prefill + # ----------------------------------------------------------------------- + + def _prefill_kv_cache( + self, + image_latents: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor | None, + frame_seqlen: int, + seq_len: int, + do_true_cfg: bool, + state: DreamZeroState, + ) -> None: + """Prefill KV cache with first frame and/or current observation. + + Uses predict_noise_maybe_with_cfg() for CFG parallel — same path as + the denoise loop. The mixin handles rank dispatch automatically. + KV cache update happens as a side effect inside predict_noise(). + """ + batch_size = image_latents.shape[0] + device = image_latents.device + dtype = image_latents.dtype + num_heads = getattr(self.transformer.blocks[0].self_attn, "tp_num_heads", self.transformer.num_heads) + head_dim = self.transformer.dim // self.transformer.num_heads + + if state.current_start_frame == 0: + state.create_kv_caches( + batch_size, + dtype, + device, + self.transformer.num_layers, + num_heads, + head_dim, + ) + + zero_t = torch.zeros([batch_size, 1], device=device, dtype=torch.long) + y_first = state.ys[:, :, 0:1] if state.ys is not None else None + + # KV cache update is a side effect in predict_noise() + common = dict( + hidden_states=image_latents.transpose(1, 2), + timestep_video=zero_t, + seq_len=frame_seqlen, + current_start_frame=0, + y=y_first, + clip_feature=state.clip_feas, + update_kv_cache=True, + dreamzero_state=state, + ) + positive_kwargs = dict( + encoder_hidden_states=prompt_embeds, + kv_cache=state.get_kv_caches(False), + crossattn_cache=state.get_crossattn_caches(False), + is_negative=False, + **common, + ) + negative_kwargs = ( + dict( + encoder_hidden_states=negative_prompt_embeds, + kv_cache=state.get_kv_caches(True), + crossattn_cache=state.get_crossattn_caches(True), + is_negative=True, + **common, + ) + if negative_prompt_embeds is not None + else None 
+ ) + + self.predict_noise_maybe_with_cfg( + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + do_true_cfg=do_true_cfg, + true_cfg_scale=self.cfg_scale, + cfg_normalize=False, + ) + state.current_start_frame = 1 + + if state.current_start_frame != 1: + csf = state.current_start_frame + nfpb = self.num_frame_per_block + current_ref = image_latents[:, -nfpb:] + if state.ys is not None and csf <= state.ys.shape[2]: + y = state.ys[:, :, csf - nfpb : csf] + elif state.ys is not None: + y = state.ys[:, :, -nfpb:] + else: + y = None + + zero_t = torch.zeros([batch_size, nfpb], device=device, dtype=torch.long) + common = dict( + hidden_states=current_ref.transpose(1, 2), + timestep_video=zero_t, + seq_len=seq_len, + current_start_frame=csf - nfpb, + y=y, + clip_feature=state.clip_feas, + update_kv_cache=True, + dreamzero_state=state, + ) + positive_kwargs = dict( + encoder_hidden_states=prompt_embeds, + kv_cache=state.get_kv_caches(False), + crossattn_cache=state.get_crossattn_caches(False), + is_negative=False, + **common, + ) + negative_kwargs = ( + dict( + encoder_hidden_states=negative_prompt_embeds, + kv_cache=state.get_kv_caches(True), + crossattn_cache=state.get_crossattn_caches(True), + is_negative=True, + **common, + ) + if negative_prompt_embeds is not None + else None + ) + + self.predict_noise_maybe_with_cfg( + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + do_true_cfg=do_true_cfg, + true_cfg_scale=self.cfg_scale, + cfg_normalize=False, + ) + + def diffuse( + self, + video_latents: torch.Tensor, + action_latents: torch.Tensor, + timesteps_video: torch.Tensor, + timesteps_action: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor | None, + video_action_scheduler: VideoActionScheduler, + do_true_cfg: bool, + state: DreamZeroState, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Denoising loop with CFG parallel support. + + For each timestep: + 1. 
Build positive_kwargs / negative_kwargs + 2. predict_noise_maybe_with_cfg() → (video_pred, action_pred) + 3. scheduler_step_maybe_with_cfg() → VideoActionScheduler + 4. _synchronize_cfg_parallel_step_output() + """ + seq_len = kwargs["seq_len"] + state_features = kwargs.get("state_features") + embodiment_id = kwargs.get("embodiment_id") + + # Shared kwargs for predict_noise (both cond & uncond branches) + common_kwargs = dict( + seq_len=seq_len, + current_start_frame=state.current_start_frame, + state_features=state_features, + embodiment_id=embodiment_id, + update_kv_cache=False, + dreamzero_state=state, + ) + + noisy_input = video_latents + noisy_input_action = action_latents + for index in range(len(timesteps_video)): + video_timestep = timesteps_video[index] + action_timestep = timesteps_action[index] + batch_size = noisy_input.shape[0] + + timestep = ( + torch.ones( + [batch_size, self.num_frame_per_block], + device=noisy_input.device, + dtype=torch.int64, + ) + * video_timestep + ) + timestep_action = ( + torch.ones( + [batch_size, self.action_horizon], + device=noisy_input.device, + dtype=torch.int64, + ) + * action_timestep + ) + + csf = state.current_start_frame + if csf + self.num_frame_per_block <= state.ys.shape[2]: + y = state.ys[:, :, csf : csf + self.num_frame_per_block] + else: + y = state.ys[:, :, -self.num_frame_per_block :] + + positive_kwargs = dict( + hidden_states=noisy_input.transpose(1, 2), + timestep_video=timestep, + encoder_hidden_states=prompt_embeds, + kv_cache=state.get_kv_caches(False), + crossattn_cache=state.get_crossattn_caches(False), + y=y, + clip_feature=state.clip_feas, + action=noisy_input_action, + timestep_action=timestep_action, + is_negative=False, + **common_kwargs, + ) + + # Negative (uncond) kwargs + if do_true_cfg and negative_prompt_embeds is not None: + negative_kwargs = dict( + hidden_states=noisy_input.transpose(1, 2), + timestep_video=timestep, + encoder_hidden_states=negative_prompt_embeds, + 
kv_cache=state.get_kv_caches(True), + crossattn_cache=state.get_crossattn_caches(True), + y=y, + clip_feature=state.clip_feas, + action=noisy_input_action, + timestep_action=timestep_action, + is_negative=True, + **common_kwargs, + ) + else: + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + do_true_cfg=do_true_cfg, + true_cfg_scale=self.cfg_scale, + cfg_normalize=False, + ) + flow_pred, flow_pred_action = noise_pred + + latents = (noisy_input, noisy_input_action) + t = (video_timestep, action_timestep) + noise_pred_tuple = (flow_pred.transpose(1, 2), flow_pred_action) + step_output = video_action_scheduler.step( + noise_pred_tuple, + t, + latents, + generator=kwargs.get("generator"), + ) + noisy_input, noisy_input_action = step_output[0] + + noisy_input, noisy_input_action = self._synchronize_cfg_parallel_step_output( + (noisy_input, noisy_input_action), + do_true_cfg, + ) + + return noisy_input, noisy_input_action + + # ----------------------------------------------------------------------- + # Main entry point + # ----------------------------------------------------------------------- + + def _transform_robot_obs(self, robot_obs: dict): + """Select DreamZero robot transform and convert raw obs to model input.""" + embodiment = robot_obs.get("embodiment", self.default_robot_embodiment) + transform = get_transform(embodiment) + return transform, transform.transform_input(robot_obs) + + @torch.no_grad() + def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput: + """Full inference step. 
Called by DiffusionEngine.step().""" + extra_args = req.sampling_params.extra_args or {} + robot_obs = extra_args.get("robot_obs") + if robot_obs is None: + first_prompt = req.prompts[0] if req.prompts else "" + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + is_dummy_warmup = prompt == "dummy run" and req.sampling_params.num_inference_steps == 1 + if is_dummy_warmup: + logger.info("Skipping DreamZero dummy warmup request without robot_obs.") + return DiffusionOutput( + output={ + "actions": np.zeros( + (self.action_horizon, self.max_action_dim), + dtype=np.float32, + ), + }, + ) + raise KeyError("robot_obs") + session_id = str(extra_args.get("session_id") or "default") + state = self._get_or_create_state(session_id) + self.state = state + transform, unified_obs = self._transform_robot_obs(robot_obs) + device = get_local_device() + + # ---- Step 1: Extract inputs from unified observation ---- + prompt_str = unified_obs["prompt"] # str (templated) + stitched = unified_obs["images"] # ndarray (T,H,W,C) from transform + if not isinstance(stitched, np.ndarray): + stitched = np.asarray(stitched) + embodiment_name = unified_obs["embodiment_name"] + embodiment_id = torch.tensor( # (B,) tensor for CategorySpecificMLP + [self.embodiment_name_to_id[embodiment_name]], + dtype=torch.long, + device=device, + ) + + # State: raw from transform → pad to (B, state_horizon=1, max_state_dim) + raw_state = unified_obs["state"] + state_for_postprocess = None + if raw_state is not None: + if not isinstance(raw_state, np.ndarray): + raw_state = np.asarray(raw_state, dtype=np.float64) + raw_state = raw_state.flatten() + padded = np.zeros(self.max_state_dim, dtype=np.float64) + n = min(len(raw_state), self.max_state_dim) + padded[:n] = raw_state[:n] + state_for_postprocess = ( + torch.from_numpy(padded) + .reshape(1, 1, self.max_state_dim) + .to( + device=device, + dtype=torch.float32, + ) + ) + state_features = self._normalize_state( + 
state_for_postprocess, + embodiment_name, + ).to(dtype=torch.bfloat16) + else: + state_features = None + + # ---- Step 1b: Tokenize ---- (wan2_2 convention: pipeline owns tokenizer) + text_inputs = self.tokenizer( + prompt_str, + max_length=512, + padding="max_length", + truncation=True, + return_tensors="pt", + add_special_tokens=True, + ) + text_tokens = text_inputs["input_ids"].to(device) + attention_mask = text_inputs["attention_mask"].to(device) + + # Explicit reset from OpenPI serving is carried by `extra_args["reset"]` + # on the next inference request after websocket reset/session switch. + if extra_args.get("reset", False): + state.reset() + # Auto-reset based on model state (before accumulation) + if state.should_reset(text_tokens, 0, self.transformer.local_attn_size): + state.reset() + state.language = text_tokens + + # Frame accumulation: stitched single frame → multi-frame video + video_frames = state.accumulate_frames(stitched) # (T, H, W, C) + videos = torch.from_numpy(video_frames).unsqueeze(0).to(device) # (B=1, T, H, W, C) + + videos = self._preprocess_video(videos) # → [B,C,T,H,W] bf16 + _, _, num_frames_raw, height, width = videos.shape + + prompt_embeds = self._encode_text(text_tokens, attention_mask) + # Negative prompt for CFG uncond branch (model constant) + negative_prompt_embeds = None + if self.cfg_scale > 1.0: + neg_inputs = self.tokenizer( + self.negative_prompt, + max_length=512, + padding="max_length", + truncation=True, + return_tensors="pt", + add_special_tokens=True, + ) + negative_prompt_embeds = self._encode_text( + neg_inputs["input_ids"].to(device), + neg_inputs["attention_mask"].to(device), + ) + + # Extract first/last frame for CLIP + VAE encoding + if num_frames_raw == 4 or num_frames_raw == 9: + image = videos[:, :, -1:].transpose(1, 2) + else: + image = videos[:, :, :1].transpose(1, 2) + + if state.current_start_frame == 0: + clip_feas, ys, image = self._encode_image( + image, + self.num_frames, + height, + width, + ) + 
state.clip_feas = clip_feas.to(dtype=image.dtype) + state.ys = ys.to(dtype=image.dtype) + + if state.current_start_frame != 0: + # Subsequent calls: encode current observation via VAE + if (num_frames_raw - 1) // 4 == self.num_frame_per_block: + pass + elif num_frames_raw // 4 != self.num_frame_per_block: + repeat_factor = self.num_frame_per_block // (num_frames_raw // 4) + videos = torch.repeat_interleave(videos, repeat_factor, dim=2) + first_frame = videos[:, :, 0:1] + videos = torch.cat([first_frame, videos], dim=2) + else: + first_frame = videos[:, :, 0:1] + videos = torch.cat([first_frame, videos], dim=2) + + latent_dtype = videos.dtype + with torch.no_grad(): + image = self._encode_vae_latents(videos) + image = image.to(dtype=latent_dtype) + + batch_size = image.shape[0] + generator = torch.Generator(device=device).manual_seed(self.seed) + noise_obs = torch.randn( + batch_size, + 16, + self.num_frame_per_block, + height // 8, + width // 8, + device=device, + dtype=torch.bfloat16, + generator=generator, + ) + generator = torch.Generator(device=device).manual_seed(self.seed) + noise_action = torch.randn( + batch_size, + self.action_horizon, + self.transformer.action_dim, + device=device, + dtype=torch.bfloat16, + generator=generator, + ) + + _, num_channels, num_frames, h_latent, w_latent = noise_obs.shape + frame_seqlen = int(h_latent * w_latent / 4) + seq_len = frame_seqlen * num_frames + + image = image.transpose(1, 2) + noise_obs = noise_obs.transpose(1, 2) + + do_true_cfg = self.cfg_scale > 1.0 and negative_prompt_embeds is not None + self._prefill_kv_cache( + image, + prompt_embeds, + negative_prompt_embeds, + frame_seqlen, + seq_len, + do_true_cfg, + state, + ) + + sample_scheduler = copy.deepcopy(self.scheduler) + sample_scheduler_action = copy.deepcopy(self.scheduler) + sample_scheduler.set_timesteps( + self.num_inference_steps, + device=device, + shift=self.sigma_shift, + ) + sample_scheduler_action.set_timesteps( + self.num_inference_steps, + 
device=device, + shift=self.sigma_shift, + ) + + if self.decouple_inference_noise: + video_final_noise = self.video_inference_final_noise + sigma_max = sample_scheduler.sigmas[0].item() + sample_scheduler.sigmas = ( + sample_scheduler.sigmas * (sigma_max - video_final_noise) / sigma_max + video_final_noise + ) + sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64) + + video_action_scheduler = VideoActionScheduler( + sample_scheduler, + sample_scheduler_action, + ) + + video_out, action_out = self.diffuse( + video_latents=noise_obs, + action_latents=noise_action, + timesteps_video=sample_scheduler.timesteps, + timesteps_action=sample_scheduler_action.timesteps, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video_action_scheduler=video_action_scheduler, + do_true_cfg=do_true_cfg, + state=state, + seq_len=seq_len, + state_features=state_features, + embodiment_id=embodiment_id, + ) + + if state.current_start_frame == 1: + video_out = torch.cat([image, video_out], dim=1) + state.current_start_frame += self.num_frame_per_block + + # q99 denorm: [-1,1] → real values + action_out = self._denormalize_action(action_out.float(), embodiment_name) + + # Relative → absolute: only for relative_action_keys (joint_position only) + # gripper_position is NOT relative, so don't add state back to it + if self.relative_action and state_for_postprocess is not None: + n_relative = self.relative_action_dim # 7 for DROID (joint only) + # Use original state precision for post-denorm absolute recovery. + # Upstream adds obs state after `eval_transform.unapply()` + # the bf16 denoising path. 
+ last_state = state_for_postprocess[:, 0, :n_relative] # (B, n_relative) + action_out[..., :n_relative] = ( + action_out[..., :n_relative] + last_state.unsqueeze(1) # broadcast over horizon + ) + + # Squeeze batch dim for output: (B, horizon, dim) → (horizon, dim) + actions_np = action_out.squeeze(0).float().cpu().numpy() # (horizon, max_action_dim) + actions_np = transform.transform_action_output(actions_np) + + return DiffusionOutput( + output={ + "actions": actions_np, + # Source `video_pred` is normalized VAE latent output, not RGB. + # Use `decode_video_latents()` for DreamZero-equivalent debug + # video decoding. + "video": video_out.transpose(1, 2).cpu(), + }, + ) + + # ----------------------------------------------------------------------- + # Action denormalization + # ----------------------------------------------------------------------- + + def _load_action_norm_stats(self, stats_path: str) -> dict[str, dict[str, torch.Tensor]]: + """Load per-embodiment action normalization stats from metadata.json. 
+ + Returns: {embodiment_name: {"q01": Tensor(action_dim,), "q99": Tensor(action_dim,)}} + """ + with open(stats_path) as f: + metadata = json.load(f) + return self._parse_action_norm_stats(metadata) + + @staticmethod + def _parse_action_norm_stats(metadata: dict) -> dict[str, dict[str, torch.Tensor]]: + result = {} + for emb_name, emb_data in metadata.items(): + action_stats = emb_data.get("statistics", {}).get("action", {}) + q01_parts, q99_parts = [], [] + # Concatenate joint_position + gripper_position stats + for key in ["joint_position", "gripper_position"]: + if key in action_stats: + q01_parts.extend(action_stats[key]["q01"]) + q99_parts.extend(action_stats[key]["q99"]) + if q01_parts: + result[emb_name] = { + "q01": torch.tensor(q01_parts, dtype=torch.float32), + "q99": torch.tensor(q99_parts, dtype=torch.float32), + } + return result + + @staticmethod + def _parse_state_norm_stats(metadata: dict) -> dict[str, dict[str, torch.Tensor]]: + """Load per-embodiment state normalization stats from metadata.json.""" + result = {} + for emb_name, emb_data in metadata.items(): + state_stats = emb_data.get("statistics", {}).get("state", {}) + q01_parts, q99_parts = [], [] + for key in ["joint_position", "gripper_position"]: + if key in state_stats: + q01_parts.extend(state_stats[key]["q01"]) + q99_parts.extend(state_stats[key]["q99"]) + if q01_parts: + result[emb_name] = { + "q01": torch.tensor(q01_parts, dtype=torch.float32), + "q99": torch.tensor(q99_parts, dtype=torch.float32), + } + return result + + def _normalize_state( + self, + state: torch.Tensor, + embodiment_name: str, + ) -> torch.Tensor: + """Normalize state with q99 stats before feeding the model.""" + state_norm_stats = getattr(self, "state_norm_stats", {}) + if embodiment_name not in state_norm_stats: + return state + stats = state_norm_stats[embodiment_name] + q01 = stats["q01"].to(device=state.device, dtype=state.dtype) + q99 = stats["q99"].to(device=state.device, dtype=state.dtype) + actual_dim = 
q01.shape[0] + normalized = state.clone() + range_vals = q99 - q01 + mask = range_vals != 0 + normalized_slice = normalized[..., :actual_dim] + normalized_slice[..., mask] = 2 * (normalized_slice[..., mask] - q01[mask]) / range_vals[mask] - 1 + normalized_slice = torch.clamp(normalized_slice, -1, 1) + normalized[..., :actual_dim] = normalized_slice + return normalized + + def _denormalize_action( + self, + action: torch.Tensor, + embodiment_name: str, + ) -> torch.Tensor: + """Denormalize action from [-1,1] to real values using q99 mode. + + Formula: real = (normalized + 1) / 2 * (q99 - q01) + q01 + """ + if embodiment_name not in self.action_norm_stats: + return action + stats = self.action_norm_stats[embodiment_name] + q01 = stats["q01"].to(device=action.device, dtype=action.dtype) + q99 = stats["q99"].to(device=action.device, dtype=action.dtype) + # action shape: (B, horizon, action_dim) or (B, horizon, max_action_dim) + # q01/q99 shape: (actual_action_dim,) — only denorm actual dims + actual_dim = q01.shape[0] + action_real = action.clone() + action_real[..., :actual_dim] = (action[..., :actual_dim] + 1) / 2 * (q99 - q01) + q01 + return action_real + + # ----------------------------------------------------------------------- + # Weight loading + # ----------------------------------------------------------------------- + + @property + def weights_sources(self): + """ComponentSource list for DiffusersPipelineLoader.""" + return self._weights_sources + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load checkpoint weights with key remapping.""" + loaded: set[str] = set() + params = dict(self.named_parameters()) + buffers = dict(self.named_buffers()) + + for name, tensor in weights: + if name.startswith("action_head.model."): + new_name = "transformer." 
+ name[len("action_head.model.") :] + new_name = ( + new_name.replace("img_emb.proj.0.", "img_emb.norm1.") + .replace("img_emb.proj.1.", "img_emb.fc1.") + .replace("img_emb.proj.3.", "img_emb.fc2.") + .replace("img_emb.proj.4.", "img_emb.norm2.") + ) + if new_name in params: + param = params[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, tensor) + loaded.add(new_name) + elif new_name in buffers: + buffers[new_name].data.copy_(tensor) + loaded.add(new_name) + + elif name.startswith("action_head.text_encoder."): + mapped = self._remap_text_encoder_key(name) + if mapped is None: + continue + for new_name in mapped if isinstance(mapped, list) else [mapped]: + full_name = "text_encoder." + new_name + if full_name in params: + params[full_name].data.copy_(tensor) + loaded.add(full_name) + + elif name.startswith("action_head.image_encoder."): + self._remap_image_encoder_key(name, tensor, params, loaded) + + elif name.startswith("action_head.vae."): + mapped = self._remap_vae_key(name) + if mapped is None: + continue + full_name = "vae." 
+ mapped + if full_name in params: + params[full_name].data.copy_(tensor) + loaded.add(full_name) + + logger.info( + "DreamZero load_weights: loaded %d parameters from root checkpoint", + len(loaded), + ) + return loaded + + # ----------------------------------------------------------------------- + # Text encoder key remapping + # ----------------------------------------------------------------------- + + @staticmethod + def _remap_text_encoder_key(name: str) -> str | list[str] | None: + """Remap a single text encoder key.""" + subkey = name[len("action_head.text_encoder.") :] + + if subkey == "token_embedding.weight": + return "shared.weight" + if subkey == "norm.weight": + return "encoder.final_layer_norm.weight" + + m = re_module.match(r"blocks\.(\d+)\.(.*)", subkey) + if not m: + return None + block_idx = m.group(1) + rest = m.group(2) + + prefix = f"encoder.block.{block_idx}" + + if rest == "attn.q.weight": + return f"{prefix}.layer.0.SelfAttention.q.weight" + if rest == "attn.k.weight": + return f"{prefix}.layer.0.SelfAttention.k.weight" + if rest == "attn.v.weight": + return f"{prefix}.layer.0.SelfAttention.v.weight" + if rest == "attn.o.weight": + return f"{prefix}.layer.0.SelfAttention.o.weight" + if rest == "pos_embedding.embedding.weight": + return f"{prefix}.layer.0.SelfAttention.relative_attention_bias.weight" + if rest == "norm1.weight": + return f"{prefix}.layer.0.layer_norm.weight" + + if rest == "ffn.gate.0.weight": + return f"{prefix}.layer.1.DenseReluDense.wi_0.weight" + if rest == "ffn.fc1.weight": + return f"{prefix}.layer.1.DenseReluDense.wi_1.weight" + if rest == "ffn.fc2.weight": + return f"{prefix}.layer.1.DenseReluDense.wo.weight" + if rest == "norm2.weight": + return f"{prefix}.layer.1.layer_norm.weight" + + return None + + # ----------------------------------------------------------------------- + # VAE key remapping + # ----------------------------------------------------------------------- + + @staticmethod + def _remap_vae_key(name: 
str) -> str | None: + """Remap DreamZero VAE keys to `DistributedAutoencoderKLWan` keys.""" + if not name.startswith("action_head.vae.model."): + return None + + rest = name[len("action_head.vae.model.") :] + + direct_prefix_map = { + "encoder.conv1.": "encoder.conv_in.", + "encoder.head.0.": "encoder.norm_out.", + "encoder.head.2.": "encoder.conv_out.", + "decoder.conv1.": "decoder.conv_in.", + "decoder.head.0.": "decoder.norm_out.", + "decoder.head.2.": "decoder.conv_out.", + "conv1.": "quant_conv.", + "conv2.": "post_quant_conv.", + } + for src_prefix, dst_prefix in direct_prefix_map.items(): + if rest.startswith(src_prefix): + return dst_prefix + rest[len(src_prefix) :] + + resnet_leaf_map = { + "residual.0.gamma": "norm1.gamma", + "residual.2.weight": "conv1.weight", + "residual.2.bias": "conv1.bias", + "residual.3.gamma": "norm2.gamma", + "residual.6.weight": "conv2.weight", + "residual.6.bias": "conv2.bias", + } + block_leaf_map = { + **resnet_leaf_map, + "shortcut.weight": "conv_shortcut.weight", + "shortcut.bias": "conv_shortcut.bias", + "resample.1.weight": "resample.1.weight", + "resample.1.bias": "resample.1.bias", + "time_conv.weight": "time_conv.weight", + "time_conv.bias": "time_conv.bias", + } + + m = re_module.match(r"encoder\.middle\.(\d+)\.(.*)", rest) + if m: + idx = int(m.group(1)) + tail = m.group(2) + if idx in (0, 2) and tail in resnet_leaf_map: + res_idx = 0 if idx == 0 else 1 + return f"encoder.mid_block.resnets.{res_idx}.{resnet_leaf_map[tail]}" + if idx == 1: + return f"encoder.mid_block.attentions.0.{tail}" + return None + + m = re_module.match(r"decoder\.middle\.(\d+)\.(.*)", rest) + if m: + idx = int(m.group(1)) + tail = m.group(2) + if idx in (0, 2) and tail in resnet_leaf_map: + res_idx = 0 if idx == 0 else 1 + return f"decoder.mid_block.resnets.{res_idx}.{resnet_leaf_map[tail]}" + if idx == 1: + return f"decoder.mid_block.attentions.0.{tail}" + return None + + m = re_module.match(r"encoder\.downsamples\.(\d+)\.(.*)", rest) + if m: 
+ idx = int(m.group(1)) + tail = m.group(2) + if tail in block_leaf_map: + return f"encoder.down_blocks.{idx}.{block_leaf_map[tail]}" + return None + + m = re_module.match(r"decoder\.upsamples\.(\d+)\.(.*)", rest) + if m: + idx = int(m.group(1)) + tail = m.group(2) + if tail not in block_leaf_map: + return None + + if idx <= 2: + prefix = f"decoder.up_blocks.0.resnets.{idx}." + elif idx == 3: + prefix = "decoder.up_blocks.0.upsamplers.0." + elif 4 <= idx <= 6: + prefix = f"decoder.up_blocks.1.resnets.{idx - 4}." + elif idx == 7: + prefix = "decoder.up_blocks.1.upsamplers.0." + elif 8 <= idx <= 10: + prefix = f"decoder.up_blocks.2.resnets.{idx - 8}." + elif idx == 11: + prefix = "decoder.up_blocks.2.upsamplers.0." + elif 12 <= idx <= 14: + prefix = f"decoder.up_blocks.3.resnets.{idx - 12}." + else: + return None + return prefix + block_leaf_map[tail] + + return None + + # ----------------------------------------------------------------------- + # Image encoder key remapping + # ----------------------------------------------------------------------- + + def _remap_image_encoder_key( + self, + name: str, + tensor: torch.Tensor, + params: dict[str, torch.nn.Parameter], + loaded: set[str], + ) -> None: + """Map an image encoder key onto the local module.""" + if not name.startswith("action_head.image_encoder."): + return + + full_name = "image_encoder." 
+ name[len("action_head.image_encoder.") :] + if full_name in params: + params[full_name].data.copy_(tensor) + loaded.add(full_name) diff --git a/vllm_omni/diffusion/models/dreamzero/state_dreamzero.py b/vllm_omni/diffusion/models/dreamzero/state_dreamzero.py new file mode 100644 index 00000000000..0972d1b5a2c --- /dev/null +++ b/vllm_omni/diffusion/models/dreamzero/state_dreamzero.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""DreamZero pipeline persistent state.""" + +from __future__ import annotations + +import logging +from collections import deque + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +# Number of frames per chunk for subsequent calls (first call uses 1) +FRAMES_PER_CHUNK = 4 + + +class DreamZeroState: + """Pipeline persistent state across forward() calls. + + Lifecycle: + - Created once in DreamZeroPipeline.__init__() + - Mutated every forward() call (frame append, KV cache grow) + - reset() on new session / language change / local_attn_size exceeded + """ + + def __init__(self) -> None: + self.reset() + + # ------------------------------------------------------------------ + # Frame accumulation (single stitched buffer) + # Transform outputs stitched single frame per call. + # We accumulate here to build multi-frame video for AR inference. + # ------------------------------------------------------------------ + + def accumulate_frames(self, stitched: np.ndarray) -> np.ndarray: + """Accumulate stitched frames and return multi-frame video. + + Args: + stitched: (H, W, C) single frame or (T, H, W, C) multi-frame, + already stitched by transform. + + Returns: + (T, H, W, C) ndarray. T=1 for first call, T=FRAMES_PER_CHUNK(4) after. 
+ """ + if stitched.ndim == 3: + self.stitched_buffer.append(stitched) + elif stitched.ndim == 4: + self.stitched_buffer.extend(list(stitched)) + else: + raise ValueError(f"Expected 3D or 4D stitched, got {stitched.ndim}D") + + num_frames = 1 if self.call_count == 0 else FRAMES_PER_CHUNK + + buffer_frames = list(self.stitched_buffer) + if len(buffer_frames) >= num_frames: + frames = buffer_frames[-num_frames:] + else: + frames = buffer_frames + while len(frames) < num_frames: + frames.insert(0, buffer_frames[0]) + + self.call_count += 1 + return np.stack(frames, axis=0) # (T, H, W, C) + + # ------------------------------------------------------------------ + # Reset / should_reset + # ------------------------------------------------------------------ + + def reset(self) -> None: + """Clear all state.""" + self.stitched_buffer: deque[np.ndarray] = deque(maxlen=FRAMES_PER_CHUNK) + self.call_count: int = 0 + + # KV cache once robot-policy diffusion supports that integration. + self.kv_cache: list[torch.Tensor] | None = None + self.kv_cache_neg: list[torch.Tensor] | None = None + self.crossattn_cache: list[dict[str, bool | torch.Tensor | None]] | None = None + self.crossattn_cache_neg: list[dict[str, bool | torch.Tensor | None]] | None = None + self.current_start_frame: int = 0 + + self.clip_feas: torch.Tensor | None = None + self.ys: torch.Tensor | None = None + self.language: torch.Tensor | None = None + + def should_reset(self, text_tokens: torch.Tensor | None, num_video_frames: int, local_attn_size: int) -> bool: + """Determine if state should be reset before this forward().""" + if self.language is None: + logger.info("language is None, resetting") + return True + + if text_tokens is not None and not torch.equal(self.language, text_tokens): + logger.info("language changed, resetting") + return True + + # NOTE: after accumulate_frames, num_video_frames is the accumulated T + # (1 for first call, 4 for subsequent). 
Only reset on true single-frame + # which happens when the stitched_buffer was cleared externally. + if num_video_frames == 1 and self.call_count > 1: + logger.info("single frame input after first call, resetting") + return True + + if local_attn_size != -1 and self.current_start_frame >= local_attn_size: + logger.info( + "current_start_frame %d >= local_attn_size %d, resetting", self.current_start_frame, local_attn_size + ) + return True + + return False + + # ------------------------------------------------------------------ + # KV cache management + # ------------------------------------------------------------------ + + def create_kv_caches( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + num_layers: int, + num_heads: int, + head_dim: int, + ) -> None: + """Initialize empty KV caches and cross-attention caches.""" + self.kv_cache = [ + torch.zeros(2, batch_size, 0, num_heads, head_dim, dtype=dtype, device=device) for _ in range(num_layers) + ] + self.kv_cache_neg = [ + torch.zeros(2, batch_size, 0, num_heads, head_dim, dtype=dtype, device=device) for _ in range(num_layers) + ] + + self.crossattn_cache = [{"is_init": False, "k": None, "v": None} for _ in range(num_layers)] + self.crossattn_cache_neg = [{"is_init": False, "k": None, "v": None} for _ in range(num_layers)] + + def update_kv_cache( + self, + layer_index: int, + updated_kv: torch.Tensor, + is_negative: bool = False, + ) -> None: + """Update a single layer's KV cache after prefill.""" + cache = self.kv_cache_neg if is_negative else self.kv_cache + if cache is None: + raise RuntimeError("KV caches not initialized, call create_kv_caches first.") + cache[layer_index] = updated_kv.clone() + + def get_kv_caches(self, is_negative: bool = False) -> list[torch.Tensor]: + """Get KV caches for the specified branch.""" + cache = self.kv_cache_neg if is_negative else self.kv_cache + if cache is None: + raise RuntimeError("KV caches not initialized.") + return cache + + def 
get_crossattn_caches(self, is_negative: bool = False) -> list[dict[str, bool | torch.Tensor | None]]: + """Get cross-attention caches for the specified branch.""" + cache = self.crossattn_cache_neg if is_negative else self.crossattn_cache + if cache is None: + raise RuntimeError("Cross-attn caches not initialized.") + return cache diff --git a/vllm_omni/diffusion/models/dreamzero/transform/__init__.py b/vllm_omni/diffusion/models/dreamzero/transform/__init__.py new file mode 100644 index 00000000000..4dec3a72b01 --- /dev/null +++ b/vllm_omni/diffusion/models/dreamzero/transform/__init__.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import importlib + +from vllm.logger import init_logger + +from vllm_omni.diffusion.models.dreamzero.transform.base import TRANSFORMS + +logger = init_logger(__name__) + +DEFAULT_EMBODIMENT = "roboarena" +_BUILTIN_TRANSFORM_MODULES = ( + "vllm_omni.diffusion.models.dreamzero.transform.droid", + "vllm_omni.diffusion.models.dreamzero.transform.roboarena", +) + + +def ensure_transforms_loaded() -> None: + """Import DreamZero transform modules and verify registration.""" + for module_name in _BUILTIN_TRANSFORM_MODULES: + try: + importlib.import_module(module_name) + except Exception as exc: + logger.exception("Failed to import DreamZero transform module %s", module_name) + raise RuntimeError(f"Failed to import DreamZero transform module '{module_name}'.") from exc + + if DEFAULT_EMBODIMENT not in TRANSFORMS: + raise RuntimeError(f"Built-in DreamZero transform '{DEFAULT_EMBODIMENT}' is not registered after import.") diff --git a/vllm_omni/diffusion/models/dreamzero/transform/base.py b/vllm_omni/diffusion/models/dreamzero/transform/base.py new file mode 100644 index 00000000000..dee78eb5450 --- /dev/null +++ b/vllm_omni/diffusion/models/dreamzero/transform/base.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Base transform interface for DreamZero robot policy serving. + +Transforms handle dataset-specific concerns ONLY: + - Observation key mapping + - Multi-view stitching (embodiment-specific layout) + - Language template wrapping (embodiment-specific) + - Raw state extraction (dataset-specific keys) + - Output action slicing (to actual action_dim) + +Model-specific concerns belong in the pipeline: + - Tokenization (pipeline owns tokenizer) + - State padding (pipeline knows MAX_STATE_DIM) + - Negative prompt (pipeline owns the string) + - Noise generation, encoding, decoding + +Flow: + raw obs (dataset format) + → DreamZeroPipeline selects transform by embodiment + → unified dict (stitched video, templated prompt str, raw state) + → tokenize, pad, encode, denoise + → transform_action_output() + → ndarray (N, action_dim) +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + + +class RobotPolicyTransform: + """Base class for dataset-specific observation transforms. + + Subclasses MUST define: + IMAGE_KEY_MAP: dict — dataset obs keys → unified keys + EMBODIMENT_NAME: str — embodiment identity (pipeline maps to numeric ID) + ACTION_DIM: int — actual action dimensions (for output slicing) + + Subclasses MUST override: + _stitch_views() — multi-view → single stitched image + _language_template() — prompt → embodiment-aware template + _extract_raw_state() — obs → raw state ndarray + """ + + IMAGE_KEY_MAP: dict[str, str] + EMBODIMENT_NAME: str + ACTION_DIM: int + + def transform_input(self, obs: dict) -> dict: + """Dataset-specific transform: key map → stitch → template → state.""" + # 1. Map image keys → unified keys + images: dict[str, np.ndarray] = {} + for src_key, dst_key in self.IMAGE_KEY_MAP.items(): + if src_key in obs: + images[dst_key] = np.asarray(obs[src_key]) + + # 2. Multi-view stitching + stitched = self._stitch_views(images) + + # 3. 
Language template (string only, pipeline tokenizes) + prompt = obs.get("prompt", "") + templated_prompt = self._language_template(prompt) + + # 4. Raw state extraction (pipeline pads) + raw_state = self._extract_raw_state(obs) + + # 5. Build unified output + unified: dict[str, Any] = { + "images": stitched, # ndarray (T, H_out, W_out, 3) + "prompt": templated_prompt, # str (templated, not tokenized) + "state": raw_state, # ndarray (state_dim,) — pipeline pads + "embodiment_name": self.EMBODIMENT_NAME, + } + if "session_id" in obs: + unified["session_id"] = obs["session_id"] + return unified + + def transform_action_output(self, actions: Any) -> np.ndarray: + """Adapt model action output to this transform's action dimensions.""" + actions = np.asarray(actions, dtype=np.float32) + # Handle any remaining batch dims: squeeze to 2D (horizon, dim) + while actions.ndim > 2: + actions = actions[0] + # Slice padded dim to actual ACTION_DIM + if actions.ndim == 2 and actions.shape[-1] > self.ACTION_DIM: + actions = actions[:, : self.ACTION_DIM] + return actions + + # ------------------------------------------------------------------ + # Subclass MUST override + # ------------------------------------------------------------------ + + def _stitch_views(self, images: dict[str, np.ndarray]) -> np.ndarray: + """Stitch camera views into single image. + Input: unified key → ndarray (H,W,3) or (T,H,W,3). + Output: ndarray (T, H_out, W_out, 3). + """ + raise NotImplementedError + + def _language_template(self, prompt: str) -> str: + """Wrap prompt in embodiment-specific template string.""" + raise NotImplementedError + + def _extract_raw_state(self, obs: dict) -> np.ndarray: + """Extract raw state vector from obs. + Returns: ndarray (state_dim,) float64. Pipeline handles padding. 
+ """ + raise NotImplementedError + + +# Transform registry — keyed by embodiment/dataset name +TRANSFORMS: dict[str, RobotPolicyTransform] = {} + + +def register_transform(name: str, transform: RobotPolicyTransform) -> None: + TRANSFORMS[name] = transform + + +def get_transform(name: str) -> RobotPolicyTransform: + if name not in TRANSFORMS: + raise KeyError(f"Unknown transform '{name}'. Available: {list(TRANSFORMS.keys())}") + return TRANSFORMS[name] diff --git a/vllm_omni/diffusion/models/dreamzero/transform/droid.py b/vllm_omni/diffusion/models/dreamzero/transform/droid.py new file mode 100644 index 00000000000..b116c54d56c --- /dev/null +++ b/vllm_omni/diffusion/models/dreamzero/transform/droid.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""DROID dataset transform. + +DROID uses 1-indexed exterior cameras, 3 views total (OXE_DROID embodiment). +Stitching layout (same as RoboArena — both are OXE_DROID): + ┌─────────────────────────┐ + │ wrist (2x width) │ ← pixel-repeat along width + ├────────────┬────────────┤ + │ left ext │ right ext │ + └────────────┴────────────┘ +""" + +from __future__ import annotations + +import numpy as np +import torch +import torchvision.transforms.v2 as T + +from vllm_omni.diffusion.models.dreamzero.transform.base import ( + RobotPolicyTransform, + register_transform, +) + + +class DroidTransform(RobotPolicyTransform): + """Transform for DROID dataset (OXE_DROID embodiment). 
+ + DROID observation keys (1-indexed exterior cameras): + observation/exterior_image_1_left → left exterior + observation/exterior_image_2_left → right exterior + observation/wrist_image_left → wrist + """ + + IMAGE_KEY_MAP = { + "observation/exterior_image_1_left": "images/exterior_0", + "observation/exterior_image_2_left": "images/exterior_1", + "observation/wrist_image_left": "images/wrist", + } + EMBODIMENT_NAME = "oxe_droid" + ACTION_DIM = 8 # 7 joint + 1 gripper + _VIDEO_CROP_SCALE = 0.95 + _VIDEO_RESIZE_HW = (176, 320) + + @classmethod + def _preprocess_view(cls, arr: np.ndarray) -> np.ndarray: + """Apply per-view crop and resize before stitching.""" + frames = torch.from_numpy(arr).to(torch.float32).permute(0, 3, 1, 2) / 255.0 + crop_h = int(arr.shape[1] * cls._VIDEO_CROP_SCALE) + crop_w = int(arr.shape[2] * cls._VIDEO_CROP_SCALE) + frames = T.CenterCrop((crop_h, crop_w))(frames) + frames = T.Resize( + cls._VIDEO_RESIZE_HW, + interpolation=T.InterpolationMode.BILINEAR, + antialias=True, + )(frames) + return (frames.permute(0, 2, 3, 1) * 255.0).to(torch.uint8).cpu().numpy() + + def _stitch_views(self, images: dict[str, np.ndarray]) -> np.ndarray: + """OXE_DROID 2x2 stitching: wrist top (2x wide), exteriors bottom.""" + left_ext = images.get("images/exterior_0") + right_ext = images.get("images/exterior_1") + wrist = images.get("images/wrist") + + # Ensure 4D: (T, H, W, C) + def ensure_4d(arr: np.ndarray | None) -> np.ndarray | None: + if arr is None: + return None + return arr if arr.ndim == 4 else arr[np.newaxis] + + left_ext = ensure_4d(left_ext) + right_ext = ensure_4d(right_ext) + wrist = ensure_4d(wrist) + + # Determine shape from first available view; with no views, return a black frame of the stitched size. + ref = next((v for v in [wrist, left_ext, right_ext] if v is not None), None) + if ref is None: + return np.zeros((1, 2 * self._VIDEO_RESIZE_HW[0], 2 * self._VIDEO_RESIZE_HW[1], 3), dtype=np.uint8) + + # Apply per-view crop + resize before stitching. 
+ def maybe_preprocess(arr: np.ndarray | None) -> np.ndarray | None: + if arr is None: + return None + return self._preprocess_view(arr) + + left_ext = maybe_preprocess(left_ext) + right_ext = maybe_preprocess(right_ext) + wrist = maybe_preprocess(wrist) + ref = next((v for v in [wrist, left_ext, right_ext] if v is not None), None) + if ref is None: + raise RuntimeError("Expected at least one DROID camera view after preprocessing.") + t, h, w, c = ref.shape + + out = np.zeros((t, 2 * h, 2 * w, c), dtype=ref.dtype) # (T, 2H, 2W, C) + + # Top row: wrist repeated 2x along width. + if wrist is not None: + wrist_wide = np.repeat(wrist, 2, axis=2) # (T, H, 2W, C) + out[:, :h, :] = wrist_wide + + # Bottom row: left exterior | right exterior. + if left_ext is not None: + out[:, h:, :w] = left_ext + if right_ext is not None: + out[:, h:, w:] = right_ext + + return out + + def _language_template(self, prompt: str) -> str: + """Expand the language prompt for the OXE_DROID multi-view format.""" + prompt = (prompt or "").strip() or "Perform the default behavior." # None/empty/blank -> default task + prompt_lower = prompt.lower() + return ( + "A multi-view video shows that a robot " + + prompt_lower + + " The video is split into three views: The top view shows the " + + "camera view from the robot's wrist, the bottom-left view shows " + + "the camera view from the left exterior camera, and the " + + "bottom-right view shows the camera view from the right exterior " + + "camera. During training, one of the two bottom exterior views " + + "may be a black screen (dropped view). 
The robot " + + prompt_lower + ) + + def _extract_raw_state(self, obs: dict) -> np.ndarray: + """OXE_DROID state: 7 joint + 1 gripper = 8 dims.""" + parts = [] + if "observation/joint_position" in obs: + parts.append(np.asarray(obs["observation/joint_position"], dtype=np.float64).flatten()) + if "observation/gripper_position" in obs: + parts.append(np.asarray(obs["observation/gripper_position"], dtype=np.float64).flatten()) + if parts: + return np.concatenate(parts) + return np.zeros(8, dtype=np.float64) + + +register_transform("droid", DroidTransform()) diff --git a/vllm_omni/diffusion/models/dreamzero/transform/roboarena.py b/vllm_omni/diffusion/models/dreamzero/transform/roboarena.py new file mode 100644 index 00000000000..d19bb0f8ce7 --- /dev/null +++ b/vllm_omni/diffusion/models/dreamzero/transform/roboarena.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""RoboArena dataset transform. + +RoboArena uses 0-indexed exterior cameras, 3 views total (OXE_DROID embodiment). +Same stitching layout as DROID — both map to OXE_DROID in DreamZero. + +""" + +from __future__ import annotations + +from vllm_omni.diffusion.models.dreamzero.transform.base import ( + register_transform, +) +from vllm_omni.diffusion.models.dreamzero.transform.droid import ( + DroidTransform, +) + + +class RoboArenaTransform(DroidTransform): + """Transform for RoboArena dataset. + + Same embodiment as DROID (OXE_DROID), same stitching and template. + Only difference: 0-indexed exterior camera keys. 
+ + RoboArena observation keys (0-indexed): + observation/exterior_image_0_left → left exterior + observation/exterior_image_1_left → right exterior + observation/wrist_image_left → wrist + + """ + + IMAGE_KEY_MAP = { + "observation/exterior_image_0_left": "images/exterior_0", + "observation/exterior_image_1_left": "images/exterior_1", + "observation/wrist_image_left": "images/wrist", + } + + +register_transform("roboarena", RoboArenaTransform()) diff --git a/vllm_omni/diffusion/models/dreamzero/utils.py b/vllm_omni/diffusion/models/dreamzero/utils.py new file mode 100644 index 00000000000..88b97ac2252 --- /dev/null +++ b/vllm_omni/diffusion/models/dreamzero/utils.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""DreamZero model constants shared by the pipeline.""" + +DEFAULT_NUM_INFERENCE_STEPS = 16 +DEFAULT_CFG_SCALE = 5.0 +DEFAULT_SIGMA_SHIFT = 5.0 +DEFAULT_SEED = 1140 + +DEFAULT_NEGATIVE_PROMPT = ( + "Vibrant colors, overexposed, static, blurry details, text, subtitles, " + "style, artwork, painting, image, still, grayscale, dull, worst quality, " + "low quality, JPEG artifacts, ugly, mutilated, extra fingers, bad hands, " + "bad face, deformed, disfigured, mutated limbs, fused fingers, stagnant " + "image, cluttered background, three legs, many people in the background, " + "walking backwards." 
+) + +DEFAULT_EMBODIMENT_NAME_TO_ID = { + "oxe_droid": 17, + "agibot": 26, + "gr1_unified": 24, + "xdof": 22, + "yam": 32, + "mecka_hands": 27, + "lapa": 27, + "dream": 31, +} diff --git a/vllm_omni/diffusion/models/dreamzero/video_export_worker.py b/vllm_omni/diffusion/models/dreamzero/video_export_worker.py new file mode 100644 index 00000000000..20c8cd8cdc7 --- /dev/null +++ b/vllm_omni/diffusion/models/dreamzero/video_export_worker.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch + + +class DreamZeroVideoExportWorkerExtension: + """DreamZero worker RPCs used by offline example video export.""" + + def decode_video_latents_to_uint8(self, video_latents: torch.Tensor) -> torch.Tensor: + if self.model_runner is None or self.model_runner.pipeline is None: + raise RuntimeError("DreamZero pipeline is not initialized on this worker.") + + with torch.inference_mode(): + decoded = self.model_runner.pipeline.decode_video_latents(video_latents) + decoded = decoded.squeeze(0).permute(1, 2, 3, 0).contiguous() + decoded = decoded.clamp(-1, 1) * 0.5 + 0.5 + decoded = (decoded * 255.0).round().to(torch.uint8).cpu() + return decoded diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index d8302c11501..3ceb720f1da 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -266,6 +266,11 @@ "pipeline_hidream_image", "HiDreamImagePipeline", ), + "DreamZeroPipeline": ( + "dreamzero", + "pipeline_dreamzero", + "DreamZeroPipeline", + ), } diff --git a/vllm_omni/diffusion/utils/hf_utils.py b/vllm_omni/diffusion/utils/hf_utils.py index 6beb1823ce0..dad02583f08 100644 --- a/vllm_omni/diffusion/utils/hf_utils.py +++ b/vllm_omni/diffusion/utils/hf_utils.py @@ -1,4 +1,5 @@ import os +from collections.abc import Mapping from functools import lru_cache from vllm.logger import init_logger @@ -27,6 +28,31 @@ 
def _looks_like_bagel(model_name: str) -> bool: return False +def _looks_like_dreamzero(model_name: str) -> bool: + """Best-effort detection for DreamZero-style VLA diffusion checkpoints.""" + try: + cfg = get_hf_file_to_dict("config.json", model_name) + if cfg.get("model_type") != "vla": + return False + action_head_cfg = cfg.get("action_head_cfg") or {} + if not isinstance(action_head_cfg, Mapping): + return False + action_head_config = action_head_cfg.get("config") or {} + if not isinstance(action_head_config, Mapping): + return False + diffusion_model_cfg = action_head_config.get("diffusion_model_cfg") or {} + if not isinstance(diffusion_model_cfg, Mapping): + return False + return ( + action_head_cfg.get("_target_") + == "groot.vla.model.dreamzero.action_head.wan_flow_matching_action_tf.WANPolicyHead" + and diffusion_model_cfg.get("_target_") + == ("groot.vla.model.dreamzero.modules.wan_video_dit_action_casual_chunk.CausalWanModel") + ) + except Exception: + return False + + @lru_cache def is_diffusion_model(model_name: str) -> bool: """Check if a model is a diffusion model. @@ -72,6 +98,7 @@ def is_diffusion_model(model_name: str) -> bool: except Exception as e: logger.debug("Failed to load diffusers config via DiffusionPipeline: %s", e) - # Bagel is not a diffusers pipeline (no model_index.json), but is still a - # diffusion-style model in vllm-omni. Detect it via config.json. - return _looks_like_bagel(model_name) + # Bagel and DreamZero are not diffusers pipelines (no model_index.json), + # but are still diffusion-style models in vllm-omni. Detect them via + # config.json. 
+ return _looks_like_bagel(model_name) or _looks_like_dreamzero(model_name) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 29788d95868..d406b3a0dc0 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -115,6 +115,7 @@ VideoListResponse, VideoResponse, ) +from vllm_omni.entrypoints.openai.realtime.robot.openpi_serving import ServingRealtimeRobotOpenPI from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection from vllm_omni.entrypoints.openai.serving_audio_generate import OmniOpenAIServingAudioGenerate from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat @@ -629,6 +630,10 @@ async def omni_init_app_state( ) state.openai_streaming_speech = None state.openai_streaming_video = None + state.openai_serving_realtime_robot = ServingRealtimeRobotOpenPI.create_policy_server( + engine_client=engine_client, + model_name=model_name, + ) state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) state.server_load_metrics = 0 @@ -947,6 +952,8 @@ async def omni_init_app_state( stage_configs=state.stage_configs, ) + state.openai_serving_realtime_robot = None + state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 state.sleeping_stages = set() @@ -1406,6 +1413,28 @@ async def realtime_websocket(websocket: WebSocket): await connection.handle_connection() +@router.websocket("/v1/realtime/robot/openpi") +async def realtime_robot_openpi(websocket: WebSocket): + """WebSocket endpoint for robot policy inference (OpenPI protocol). + + Binary frames: msgpack observation/action (OpenPI compatible). + Text frames: not handled; RobotRealtimeConnection skips any frame without a binary payload. + See realtime.robot.openpi_connection.py for protocol details. 
+ """ + from vllm_omni.entrypoints.openai.realtime.robot.openpi_connection import ( + RobotRealtimeConnection, + ) + + serving = getattr(websocket.app.state, "openai_serving_realtime_robot", None) + if serving is None: + await websocket.accept() + await websocket.send_json({"type": "error", "error": "Robot policy not available", "code": "unsupported"}) + await websocket.close() + return + connection = RobotRealtimeConnection(websocket, serving) + await connection.handle_connection() + + # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openai/realtime/robot/__init__.py b/vllm_omni/entrypoints/openai/realtime/robot/__init__.py new file mode 100644 index 00000000000..9881313609a --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/robot/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/vllm_omni/entrypoints/openai/realtime/robot/openpi_connection.py b/vllm_omni/entrypoints/openai/realtime/robot/openpi_connection.py new file mode 100644 index 00000000000..2ede4fcd7ad --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/robot/openpi_connection.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""WebSocket connection for robot policy inference (OpenPI protocol). 
+ +Protocol (compatible with OpenPI policy clients): + Connect -> server sends msgpack(PolicyServerConfig fields) + Infer -> client sends msgpack(obs), server sends msgpack(ndarray) + Reset -> client sends msgpack({endpoint:reset}), server sends msgpack(status) +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from fastapi import WebSocket +from starlette.websockets import WebSocketDisconnect +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.realtime.robot.openpi_serving import ( + ServingRealtimeRobotOpenPI, +) + +logger = init_logger(__name__) +_DEFAULT_IDLE_TIMEOUT = 30.0 +MAX_OPENPI_PAYLOAD_BYTES = 64 * 1024 * 1024 + + +def _get_msgpack_numpy() -> Any: + try: + from openpi_client import msgpack_numpy + except ImportError as exc: + raise ImportError( + "The `/v1/realtime/robot/openpi` endpoint requires the optional " + "`openpi-client` dependency. Install it with `pip install openpi-client`." + ) from exc + + return msgpack_numpy + + +def _pack(obj: Any) -> bytes: + return _get_msgpack_numpy().packb(obj) + + +def _unpack(data: bytes) -> Any: + return _get_msgpack_numpy().unpackb(data) + + +class RobotRealtimeConnection: + """WebSocket connection for robot policy inference.""" + + def __init__( + self, + websocket: WebSocket, + serving: ServingRealtimeRobotOpenPI, + idle_timeout: float = _DEFAULT_IDLE_TIMEOUT, + ) -> None: + self.websocket = websocket + self.serving = serving + self._idle_timeout = idle_timeout + self._current_session_id: str | None = None + self._call_count = 0 + + def reset(self) -> None: + self._current_session_id = None + self._call_count = 0 + + async def _send_error(self, message: str) -> None: + await self.websocket.send_bytes(_pack({"type": "error", "message": message})) + + def _unpack_request(self, data: bytes) -> dict[str, Any]: + if len(data) > MAX_OPENPI_PAYLOAD_BYTES: + raise ValueError("OpenPI request payload too large") + obs = _unpack(data) + if not isinstance(obs, 
dict): + raise ValueError("Invalid request payload") + return obs + + async def handle_connection(self) -> None: + """Main loop for OpenPI-compatible policy serving.""" + await self.websocket.accept() + + try: + # Send model-specific PolicyServerConfig resolved by serving from + # diffusion od_config.model_config. + metadata = self.serving.policy_server_config.to_dict() + await self.websocket.send_bytes(_pack(metadata)) + + while True: + try: + msg = await asyncio.wait_for( + self.websocket.receive(), + timeout=self._idle_timeout, + ) + except asyncio.TimeoutError: + logger.info("Robot OpenPI connection idle timeout after %.1f seconds", self._idle_timeout) + try: + await self.websocket.close() + except Exception: + logger.debug("Failed to close idle robot OpenPI websocket", exc_info=True) + return + + if msg.get("type") == "websocket.disconnect": + break + + if "bytes" not in msg or not msg["bytes"]: + continue + + try: + obs = self._unpack_request(msg["bytes"]) + except Exception: + logger.exception("Invalid robot OpenPI request payload") + try: + await self._send_error("Invalid request payload") + except Exception: + break + continue + + try: + endpoint = obs.pop("endpoint", "infer") + + if endpoint == "reset": + self.reset() + self.serving.reset(obs) + await self.websocket.send_bytes(_pack({"status": "reset successful"})) + else: + session_id = str(obs.get("session_id") or self._current_session_id or "default") + if session_id != self._current_session_id: + if self._current_session_id is not None: + logger.info( + "Robot OpenPI session changed %s -> %s", + self._current_session_id, + session_id, + ) + self._current_session_id = session_id + self._call_count = 0 + + self._call_count += 1 + actions = await self.serving.infer( + obs, + session_id=session_id, + reset=self._call_count <= 1, + ) + await self.websocket.send_bytes(_pack(actions)) + except Exception: + logger.exception("Error handling request") + try: + await self._send_error("Internal inference error") 
+ except Exception: + break + + except WebSocketDisconnect: + pass + except Exception: + logger.exception("Connection error") diff --git a/vllm_omni/entrypoints/openai/realtime/robot/openpi_serving.py b/vllm_omni/entrypoints/openai/realtime/robot/openpi_serving.py new file mode 100644 index 00000000000..46e46276a9c --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/robot/openpi_serving.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Serving layer for robot policy inference via `/v1/realtime/robot/openpi`. + +Flow: raw obs → engine request → actions. +The loaded policy model owns dataset transforms inside its pipeline. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from itertools import count +from typing import Any + +import numpy as np +from omegaconf import OmegaConf +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _to_builtin_container(value: Any) -> Any: + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=True) + if isinstance(value, Mapping): + return {key: _to_builtin_container(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_to_builtin_container(item) for item in value] + return value + + +@dataclass(frozen=True) +class PolicyServerConfig: + """OpenPI policy server handshake config. + + Values are model-specific and must be provided by the loaded policy model. 
+ """ + + values: dict[str, Any] + + @classmethod + def from_model_config(cls, model_config: Any) -> PolicyServerConfig: + if isinstance(model_config, Mapping): + raw_config = model_config.get("policy_server_config") + else: + raw_config = getattr(model_config, "policy_server_config", None) + + if raw_config is None: + raise ValueError("Robot OpenPI serving requires policy_server_config.") + if isinstance(raw_config, cls): + return raw_config + if not isinstance(raw_config, Mapping): + raise TypeError("policy_server_config must be a dict.") + return cls(_to_builtin_container(raw_config)) + + def to_dict(self) -> dict[str, Any]: + return _to_builtin_container(self.values) + + +class ServingRealtimeRobotOpenPI: + """Robot policy serving layer for OpenPI protocol. + + Model-specific transform/state lives in the diffusion pipeline. + """ + + def __init__( + self, + engine_client: Any, + model_name: str | None = None, + ) -> None: + self.engine_client = engine_client + self.model_name = model_name + self.policy_server_config = self._get_policy_server_config(engine_client) + self._request_counter = count() + + @classmethod + def create_policy_server( + cls, + engine_client: Any, + model_name: str | None = None, + ) -> ServingRealtimeRobotOpenPI | None: + try: + return cls(engine_client=engine_client, model_name=model_name) + except ValueError as exc: + if "policy_server_config" not in str(exc): + raise + logger.info("Robot OpenPI serving disabled for model %s", model_name) + return None + + @staticmethod + def _get_policy_server_config(engine_client: Any) -> PolicyServerConfig: + model_config = None + get_od_config = getattr(engine_client, "get_diffusion_od_config", None) + if callable(get_od_config): + od_config = get_od_config() + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + for stage_config in getattr(engine_client, "stage_configs", []) or []: + if getattr(stage_config, "stage_type", None) != "diffusion": + continue + 
engine_args = getattr(stage_config, "engine_args", None) + model_config = getattr(engine_args, "model_config", None) + if model_config is not None: + break + + if model_config is None: + od_config = getattr(engine_client, "od_config", None) + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + model_config = getattr(engine_client, "model_config", None) + return PolicyServerConfig.from_model_config(model_config) + + def reset(self, obs: dict) -> None: + """Compatibility hook; per-connection state lives in RobotRealtimeConnection.""" + + async def infer(self, obs: dict, *, session_id: str, reset: bool) -> np.ndarray: + """raw obs → engine → actions.""" + # Build request, run inference through AsyncOmni + request = self._build_request(obs, session_id=session_id, reset=reset) + result = None + # OpenPI policy serving is one request -> one action reply. AsyncOmni + # exposes an async iterator, so consume it to completion and use the + # final output, matching other non-streaming OpenAI serving paths. + async for output in self.engine_client.generate( + prompt=request.prompts[0], + request_id=request.request_ids[0], + sampling_params_list=[request.sampling_params], + ): + result = output + if result is None: + raise RuntimeError("Robot OpenPI request produced no output.") + + return self._extract_actions(result) + + def _next_request_id(self, session_id: str) -> str: + """Return a unique engine request id while keeping session_id stateful.""" + return f"robot-{session_id}-{next(self._request_counter)}" + + def _build_request(self, obs: dict, *, session_id: str, reset: bool) -> Any: + """Build engine request from raw robot obs. + + Returns an `OmniDiffusionRequest` payload consumed by + `AsyncOmni.generate()` and routed to the diffusion stage. 
+ """ + from vllm_omni.diffusion.request import OmniDiffusionRequest + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + extra_args = { + "reset": reset, + "session_id": session_id, + "robot_obs": obs, + } + + prompt = obs.get("prompt", "") + sampling_params = OmniDiffusionSamplingParams(extra_args=extra_args) + return OmniDiffusionRequest( + prompts=[prompt], + sampling_params=sampling_params, + request_ids=[self._next_request_id(session_id)], + ) + + def _extract_actions(self, result: Any) -> np.ndarray: + """Extract actions from engine result.""" + if hasattr(result, "__iter__"): + result = list(result) + if result: + result = result[0] + + if not hasattr(result, "multimodal_output") or result.multimodal_output is None: + raise RuntimeError("Missing multimodal_output in robot policy result") + + actions = result.multimodal_output.get("actions") + if actions is None: + raise RuntimeError("Missing multimodal_output['actions'] in robot policy result") + return np.asarray(actions, dtype=np.float32) diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index de7eb5f4c7e..578d7c04a21 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -14,6 +14,7 @@ from vllm_omni.config.stage_config import StageConfigFactory from vllm_omni.config.yaml_util import create_config, load_yaml_config, merge_configs +from vllm_omni.diffusion.utils.hf_utils import _looks_like_dreamzero from vllm_omni.entrypoints.stage_utils import _to_dict from vllm_omni.inputs.data import OmniSamplingParams from vllm_omni.platforms import current_omni_platform @@ -336,6 +337,9 @@ def resolve_model_config_path(model: str) -> str: ) default_config_path = current_omni_platform.get_default_stage_config_path() + if model_type == "vla" and _looks_like_dreamzero(model): + model_type = "dreamzero" + if model_type in _DIFFUSERS_CLASS_TO_CONFIG: normalized_model_type = _DIFFUSERS_CLASS_TO_CONFIG[model_type] else: diff --git 
a/vllm_omni/model_executor/models/dreamzero/__init__.py b/vllm_omni/model_executor/models/dreamzero/__init__.py new file mode 100644 index 00000000000..0897383fc8d --- /dev/null +++ b/vllm_omni/model_executor/models/dreamzero/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.model_executor.models.dreamzero.pipeline import DREAMZERO_PIPELINE + +__all__ = ["DREAMZERO_PIPELINE"] diff --git a/vllm_omni/model_executor/models/dreamzero/pipeline.py b/vllm_omni/model_executor/models/dreamzero/pipeline.py new file mode 100644 index 00000000000..bc815dedcf7 --- /dev/null +++ b/vllm_omni/model_executor/models/dreamzero/pipeline.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""DreamZero single-stage diffusion topology.""" + +from vllm_omni.config.stage_config import ( + PipelineConfig, + StageExecutionType, + StagePipelineConfig, +) + +DREAMZERO_PIPELINE = PipelineConfig( + model_type="dreamzero", + model_arch="DreamZeroPipeline", + stages=( + StagePipelineConfig( + stage_id=0, + model_stage="diffusion", + execution_type=StageExecutionType.DIFFUSION, + input_sources=(), + final_output=True, + final_output_type="image", + model_arch="DreamZeroPipeline", + ), + ), +)