diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index f58123317de..3b1e60b47ae 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -22,6 +22,7 @@ th { | `MingFlashOmniForConditionalGeneration` + `MingImagePipeline` | Ming-flash-omni-2.0 (omni-speech + imagegen1) | `Jonathan1909/Ming-flash-omni-2.0` | ✅︎ | | | | | `BagelForConditionalGeneration` | BAGEL (DiT-only) | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ | | ✅︎ | | `InternVLAA1Pipeline` | InternVLA-A1 | `InternRobotics/InternVLA-A1-3B` | ✅︎ | ✅︎ | | | +| `Gr00tN1d7Pipeline` | GR00T N1.7 | `nvidia/GR00T-N1.7-3B` | ✅︎ | | | | | `HunyuanImage3ForCausalMM` | HunyuanImage3.0 (DiT-only) | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `QwenImagePipeline` | Qwen-Image-2512 | `Qwen/Qwen-Image-2512` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | diff --git a/recipes/NVIDIA/GR00T-N1.7.md b/recipes/NVIDIA/GR00T-N1.7.md new file mode 100644 index 00000000000..0a314954026 --- /dev/null +++ b/recipes/NVIDIA/GR00T-N1.7.md @@ -0,0 +1,86 @@ +# GR00T-N1.7 + +> NVIDIA Isaac GR00T-N1.7-3B robot VLA policy served over the OpenPI WebSocket protocol + +## Summary + +- Vendor: NVIDIA +- Model: `nvidia/GR00T-N1.7-3B` +- Task: Vision-Language-Action (VLA) inference for robot manipulation +- Mode: Online serving via OpenPI WebSocket endpoint +- Maintainer: timzsu + +## When to use this recipe + +Use this recipe when you need to serve GR00T-N1.7 as a real-time robot policy +over the OpenPI WebSocket API. It configures the DROID embodiment +(`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`) and exposes the standard DROID action +keys (`eef_9d`, `gripper_position`, `joint_position`) with action horizon 40. + +## References + +- Upstream model: +- Upstream codebase: +- OpenPI client library: +- Pipeline: `vllm_omni.diffusion.models.gr00t.pipeline_gr00t.Gr00tN1d7Pipeline` +- Deploy config: [`vllm_omni/deploy/Gr00tN1d7.yaml`](../../vllm_omni/deploy/Gr00tN1d7.yaml) +- E2E test: [`tests/e2e/online_serving/test_gr00t_openpi.py`](../../tests/e2e/online_serving/test_gr00t_openpi.py) + +## Environment + +- OS: Linux +- Python: 3.11+ +- Driver / runtime: NVIDIA CUDA +- vLLM-Omni version or commit: use versions from your current checkout + +## Start server + +From repository root: + +```bash +vllm serve nvidia/GR00T-N1.7-3B \ + --omni \ + --host 127.0.0.1 \ + --port 8000 \ + --served-model-name gr00t-n1d7 \ + --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml +``` + +Notes: + +- Only `max_num_seqs: 1` is supported (configured in the deploy YAML); GR00T + policy state is per-session and not designed for concurrent batching. +- The WebSocket endpoint is `ws://127.0.0.1:8000/v1/realtime/robot/openpi`. + The server handshake message (first frame after connect) is a msgpack-encoded + dict with `action_horizon`, `action_keys`, `embodiment_tag`, and + `needs_session_id`. + +## Verification + +```python +from tests.gr00t.openpi_client_helper import run_policy_session, validate_session_result +validate_session_result(run_policy_session(host="127.0.0.1", port=8000)) +``` + +Or run the e2e test suite: + +```bash +python -m pytest tests/e2e/online_serving/test_gr00t_openpi.py -v +``` + +The test sends a synthetic two-frame DROID observation and checks: + +- GR00T metadata contract: `image_resolution`, `action_horizon`, `action_keys`, `embodiment_tag` +- Action shapes: `eef_9d (1,40,9)`, `gripper_position (1,40,1)`, `joint_position (1,40,7)` +- All action values are finite float32 +- Reset response is `"reset successful"` + +## Notes + +- To switch embodiment, edit `embodiment_tag` under both `model_config` and + `policy_server_config` in `vllm_omni/deploy/Gr00tN1d7.yaml`. Supported values: + `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` (default), `XDOF`, `XDOF_SUBTASK`, + `REAL_G1`, `REAL_R1_PRO_SHARPA`, `LIBERO_PANDA`, `SIMPLER_ENV_GOOGLE`, + `SIMPLER_ENV_WIDOWX`. +- GR00T weights are loaded directly by `Gr00tPolicy` via `AutoModel.from_pretrained`; + the pipeline's `load_weights` is intentionally a no-op. diff --git a/tests/diffusion/models/gr00t/__init__.py b/tests/diffusion/models/gr00t/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/models/gr00t/test_pipeline.py b/tests/diffusion/models/gr00t/test_pipeline.py new file mode 100644 index 00000000000..dd75773ec3f --- /dev/null +++ b/tests/diffusion/models/gr00t/test_pipeline.py @@ -0,0 +1,142 @@ +from types import SimpleNamespace + +import numpy as np +import pytest + +from vllm_omni.diffusion.models.gr00t import pipeline_gr00t +from vllm_omni.diffusion.models.gr00t.pipeline_gr00t import Gr00tN1d7Pipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class FakeGr00tPolicy: + instances = [] + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.reset_calls = 0 + self.seen_obs = None + self.embodiment_tag = SimpleNamespace(value="fake_embodiment") + self.language_key = "annotation.language.language_instruction" + self.modality_configs = { + "action": SimpleNamespace( + delta_indices=[0, 1], + modality_keys=["arm", "gripper"], + ) + } + self.processor = SimpleNamespace( + state_action_processor=SimpleNamespace( + norm_params={ + "fake_embodiment": { + "action": { + "arm": {"dim": np.array(2)}, + "gripper": {"dim": np.array(1)}, + } + } + } + ) + ) + FakeGr00tPolicy.instances.append(self) + + def get_action(self, obs): + self.seen_obs = obs + return { + "arm": np.array([[[1.0, 2.0]]], dtype=np.float64), + "gripper": [[[3.0]]], + }, {"latency_ms": 1.0} + + def reset(self): + self.reset_calls += 1 + return {"reset": True} + + +@pytest.fixture(autouse=True) +def fake_gr00t_policy(monkeypatch): + FakeGr00tPolicy.instances.clear() + monkeypatch.setattr(pipeline_gr00t, "Gr00tPolicy", FakeGr00tPolicy) + + +def _pipeline(): + od_config = SimpleNamespace( + model="nvidia/GR00T-N1.7-3B", + model_config={ + "embodiment_tag": "LIBERO_PANDA", + "strict": False, + }, + custom_pipeline_args={}, + ) + return Gr00tN1d7Pipeline(od_config=od_config) + + +def test_pipeline_initializes_local_policy(): + pipeline = _pipeline() + + policy = FakeGr00tPolicy.instances[0] + assert policy.kwargs["model_path"] == "nvidia/GR00T-N1.7-3B" + assert policy.kwargs["embodiment_tag"] == "LIBERO_PANDA" + assert policy.kwargs["strict"] is False + assert pipeline.weights_sources == () + assert pipeline.load_weights(iter(())) == set() + + +def test_forward_returns_dict_actions_in_multimodal_output(): + pipeline = _pipeline() + req = OmniDiffusionRequest( + prompts=["pick"], + request_ids=["req"], + sampling_params=OmniDiffusionSamplingParams( + extra_args={ + "robot_obs": { + "images": {"cam": np.zeros((1, 1, 8, 8, 3), dtype=np.uint8)}, + "state": {"joint": np.zeros((1, 1, 2), dtype=np.float32)}, + "prompt": "pick the cube", + "session_id": "session-a", + }, + "reset": True, + } + ), + ) + + output = pipeline.forward(req) + + assert output.error is None + assert output.output is None + actions = output.multimodal_output["actions"] + assert set(actions) == {"arm", "gripper"} + assert actions["arm"].dtype == np.float32 + np.testing.assert_allclose(actions["arm"], np.array([[[1.0, 2.0]]], dtype=np.float32)) + policy = FakeGr00tPolicy.instances[0] + assert "video" in policy.seen_obs + assert policy.seen_obs["language"] == {"annotation.language.language_instruction": [["pick the cube"]]} + assert "images" not in policy.seen_obs + assert "prompt" not in policy.seen_obs + assert "session_id" not in policy.seen_obs + assert policy.reset_calls == 1 + + +def test_dummy_warmup_returns_shape_correct_zero_actions(): + pipeline = _pipeline() + req = OmniDiffusionRequest( + prompts=["dummy run"], + request_ids=["dummy_req_id"], + sampling_params=OmniDiffusionSamplingParams(num_inference_steps=1), + ) + + output = pipeline.forward(req) + + assert output.error is None + actions = output.multimodal_output["actions"] + assert set(actions) == {"arm", "gripper"} + assert actions["arm"].shape == (1, 2, 2) + assert actions["gripper"].shape == (1, 2, 1) + assert not actions["arm"].any() + assert FakeGr00tPolicy.instances[0].seen_obs is None + + +def test_reset_delegates_to_policy(): + pipeline = _pipeline() + + assert pipeline.reset() == {"reset": True} + assert FakeGr00tPolicy.instances[0].reset_calls == 1 diff --git a/tests/e2e/online_serving/test_gr00t_openpi.py b/tests/e2e/online_serving/test_gr00t_openpi.py new file mode 100644 index 00000000000..51f1a386cca --- /dev/null +++ b/tests/e2e/online_serving/test_gr00t_openpi.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""End-to-end online serving test for GR00T N1.7 through the OpenPI robot endpoint.""" + +import os + +import numpy as np +import pytest + +from tests.gr00t import openpi_client_helper as openpi_client +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams +from tests.helpers.stage_config import get_deploy_config_path + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODEL = "nvidia/GR00T-N1.7-3B" +openpi_client.require_dependencies() + +test_params = [ + pytest.param( + OmniServerParams( + model=MODEL, + stage_config_path=get_deploy_config_path("Gr00tN1d7.yaml"), + server_args=["--disable-log-stats"], + env_dict={"VLLM_DISABLE_COMPILE_CACHE": "1", "GR00T_NOISE_SEED": "42"}, + init_timeout=1200, + stage_init_timeout=900, + ), + id="gr00t-n1d7-openpi", + ) +] + +# Reference values captured from the Isaac-GR00T ZMQ server (nvidia/GR00T-N1.7-3B, +# OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT embodiment) using build_droid_observation() +# inputs: zero images (256×256), identity eef_9d, zero gripper/joint, "pick up the object". +# Outputs are bit-reproducible across resets and across runs (max_diff=0.0). +_REF_EEF_9D = np.array( + [ + [ + 0.014888550154864788, + -0.00039258378092199564, + -0.013574761338531971, + 0.9999837875366211, + 0.005678884219378233, + -0.00034708858584053814, + -0.005677700508385897, + 0.9999783635139465, + 0.0033210734836757183, + ], + [ + 0.020188070833683014, + -0.0003105341165792197, + -0.02232760563492775, + 0.9997346997261047, + 0.02301345206797123, + 0.000983047066256404, + -0.02302766777575016, + 0.9995648860931396, + 0.018431292846798897, + ], + [ + -0.007266733795404434, + -0.05537768080830574, + 0.03667901083827019, + 0.992686927318573, + 0.1206541359424591, + 0.003906540106981993, + -0.12024091929197311, + 0.9853788614273071, + 0.12070894986391068, + ], + ], + dtype=np.float32, +) # rows = step 0, step 4, step 39 + +_REF_GRIPPER = np.array([[0.0], [0.0078125], [0.939453125]], dtype=np.float32) # steps 0, 4, 39 + +_REF_JOINT = np.array( + [ + [ + -0.0010484338272362947, + 0.0014262489276006818, + -0.003565810853615403, + -3.846167237497866e-05, + -0.0002604846959002316, + 0.008521700277924538, + -0.006872728932648897, + ], + [ + -0.009435676969587803, + 0.0021475711837410927, + -0.0031688229646533728, + -1.8328893929719925e-05, + 0.0005945992306806147, + 0.019159257411956787, + 0.0009468861389905214, + ], + [ + -0.02944088727235794, + -0.08419207483530045, + -0.0251418836414814, + 0.00540524534881115, + 0.04752273112535477, + -0.012884500436484814, + 0.024298785254359245, + ], + ], + dtype=np.float32, +) # rows = step 0, step 4, step 39 + + +@pytest.mark.advanced_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100"}, num_cards=1) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_gr00t_n1d7_openpi_online(omni_server) -> None: + result = openpi_client.run_policy_session( + host=omni_server.host, + port=omni_server.port, + session_id="gr00t-online-e2e", + ) + openpi_client.validate_session_result(result) + + +@pytest.mark.advanced_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100"}, num_cards=1) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_gr00t_n1d7_openpi_precision(omni_server) -> None: + """Assert actions match Isaac-GR00T reference (GR00T_NOISE_SEED=42, zero inputs).""" + client = openpi_client.OpenPIWebsocketClient(host=omni_server.host, port=omni_server.port) + try: + obs = openpi_client.build_droid_observation(session_id="gr00t-precision-e2e") + client.reset({}) + actions = client.infer(obs) + finally: + client.close() + + steps = [0, 4, 39] + np.testing.assert_allclose( + actions["eef_9d"][0, steps, :], + _REF_EEF_9D, + atol=1e-2, + rtol=0.0, + err_msg="eef_9d action mismatch vs Isaac-GR00T reference", + ) + np.testing.assert_allclose( + actions["gripper_position"][0, steps, :], + _REF_GRIPPER, + atol=1e-2, + rtol=0.0, + err_msg="gripper_position action mismatch vs Isaac-GR00T reference", + ) + np.testing.assert_allclose( + actions["joint_position"][0, steps, :], + _REF_JOINT, + atol=1e-2, + rtol=0.0, + err_msg="joint_position action mismatch vs Isaac-GR00T reference", + ) diff --git a/tests/gr00t/__init__.py b/tests/gr00t/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/tests/gr00t/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/gr00t/openpi_client_helper.py b/tests/gr00t/openpi_client_helper.py new file mode 100644 index 00000000000..e5943ca3307 --- /dev/null +++ b/tests/gr00t/openpi_client_helper.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any + +import numpy as np + +try: + import websockets.sync.client as websockets_client +except ImportError: # pragma: no cover - optional e2e dependency + websockets_client = None + +try: + from openpi_client import msgpack_numpy +except ImportError: # pragma: no cover - optional e2e dependency + msgpack_numpy = None + +PING_INTERVAL_SECS = 300 +PING_TIMEOUT_SECS = 3600 +DEFAULT_HOST = "127.0.0.1" +DEFAULT_PORT = 8000 +DEFAULT_PATH = "/v1/realtime/robot/openpi" +DEFAULT_SESSION_ID = "gr00t-smoke" +ACTION_KEYS = {"eef_9d", "gripper_position", "joint_position"} +LANGUAGE_KEY = "annotation.language.language_instruction" + + +def _identity_eef_9d_state() -> np.ndarray: + state = np.zeros((1, 1, 9), dtype=np.float32) + state[..., 3:] = np.array([1, 0, 0, 0, 1, 0], dtype=np.float32) + return state + + +def require_dependencies() -> None: + missing = [] + if websockets_client is None: + missing.append("websockets") + if msgpack_numpy is None: + missing.append("openpi-client") + if missing: + raise ModuleNotFoundError(f"GR00T OpenPI test dependencies are missing: {', '.join(missing)}") + + +@dataclass(frozen=True) +class Gr00tServerMetadata: + action_horizon: int + action_keys: set[str] + embodiment_tag: str + needs_session_id: bool + + @classmethod + def from_dict(cls, payload: dict[str, Any]): + required_keys = ("action_horizon", "action_keys", "embodiment_tag", "needs_session_id") + missing_keys = [key for key in required_keys if key not in payload] + if missing_keys: + raise ValueError(f"Missing GR00T metadata keys: {missing_keys}") + + return cls( + action_horizon=int(payload["action_horizon"]), + action_keys={str(key) for key in payload["action_keys"]}, + embodiment_tag=str(payload["embodiment_tag"]), + needs_session_id=bool(payload["needs_session_id"]), + ) + + +class OpenPIWebsocketClient: + def __init__( + self, + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + ) -> None: + require_dependencies() + self._uri = f"ws://{host}:{port}{path}" + self._packer = msgpack_numpy.Packer() + self._ws, self._server_metadata = self._connect() + + def _connect(self): + conn = websockets_client.connect( + self._uri, + compression=None, + max_size=None, + ping_interval=PING_INTERVAL_SECS, + ping_timeout=PING_TIMEOUT_SECS, + ) + metadata = msgpack_numpy.unpackb(conn.recv()) + if not isinstance(metadata, dict): + raise TypeError(f"Expected dict metadata from server, got {type(metadata)!r}") + return conn, metadata + + def get_server_metadata(self) -> dict[str, Any]: + return dict(self._server_metadata) + + def infer(self, obs: dict[str, Any]) -> dict[str, np.ndarray]: + payload = dict(obs) + payload["endpoint"] = "infer" + self._ws.send(self._packer.pack(payload)) + response = msgpack_numpy.unpackb(self._ws.recv()) + if isinstance(response, dict) and response.get("type") == "error": + raise RuntimeError(f"Inference failed: {response['message']}") + if not isinstance(response, dict): + raise RuntimeError(f"Expected dict actions from GR00T OpenPI endpoint, got {type(response)!r}") + return {str(key): np.asarray(value, dtype=np.float32) for key, value in response.items()} + + def reset(self, reset_info: dict[str, Any] | None = None) -> str: + payload = dict(reset_info or {}) + payload["endpoint"] = "reset" + self._ws.send(self._packer.pack(payload)) + response = msgpack_numpy.unpackb(self._ws.recv()) + if not isinstance(response, dict) or response.get("status") != "reset successful": + raise RuntimeError(f"Unexpected reset response: {response!r}") + return str(response["status"]) + + def close(self) -> None: + self._ws.close() + + +def build_droid_observation(*, session_id: str = DEFAULT_SESSION_ID) -> dict[str, Any]: + return { + "session_id": session_id, + "video": { + "exterior_image_1_left": np.zeros((1, 2, 256, 256, 3), dtype=np.uint8), + "wrist_image_left": np.zeros((1, 2, 256, 256, 3), dtype=np.uint8), + }, + "state": { + "eef_9d": _identity_eef_9d_state(), + "gripper_position": np.zeros((1, 1, 1), dtype=np.float32), + "joint_position": np.zeros((1, 1, 7), dtype=np.float32), + }, + "language": {LANGUAGE_KEY: [["pick up the object"]]}, + } + + +def run_policy_session( + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + session_id: str = DEFAULT_SESSION_ID, +) -> dict[str, Any]: + client = OpenPIWebsocketClient(host=host, port=port, path=path) + try: + metadata = client.get_server_metadata() + actions = client.infer(build_droid_observation(session_id=session_id)) + reset_status = client.reset({}) + return { + "metadata": metadata, + "actions": actions, + "reset_status": reset_status, + "session_id": session_id, + } + finally: + client.close() + + +def validate_session_result(result: dict[str, Any]) -> None: + metadata = Gr00tServerMetadata.from_dict(result["metadata"]) + if metadata.embodiment_tag != "OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT": + raise AssertionError(f"Unexpected embodiment_tag: {metadata.embodiment_tag}") + if not metadata.needs_session_id: + raise AssertionError("GR00T test expects needs_session_id metadata") + if metadata.action_keys != ACTION_KEYS: + raise AssertionError(f"Unexpected action keys: {metadata.action_keys}") + if result["reset_status"] != "reset successful": + raise AssertionError(f"Unexpected reset status: {result['reset_status']!r}") + + actions = result["actions"] + if set(actions) != ACTION_KEYS: + raise AssertionError(f"Unexpected action keys: {set(actions)}") + expected_shapes = { + "eef_9d": (1, metadata.action_horizon, 9), + "gripper_position": (1, metadata.action_horizon, 1), + "joint_position": (1, metadata.action_horizon, 7), + } + for key, expected_shape in expected_shapes.items(): + if actions[key].shape != expected_shape: + raise AssertionError(f"Action {key} shape mismatch: expected {expected_shape}, got {actions[key].shape}") + if actions[key].dtype != np.float32: + raise AssertionError(f"Action {key} dtype mismatch: expected float32, got {actions[key].dtype}") + if not np.isfinite(actions[key]).all(): + raise AssertionError(f"Action {key} contains non-finite values") diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py index 140bddf45ae..c375cc38d60 100644 --- a/vllm_omni/config/pipeline_registry.py +++ b/vllm_omni/config/pipeline_registry.py @@ -28,8 +28,6 @@ the entries declared here. """ -from __future__ import annotations - # --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) --- _OMNI_PIPELINES: dict[str, tuple[str, str]] = { # model_type -> (module_path, variable_name) @@ -69,6 +67,10 @@ "vllm_omni.model_executor.models.glm_image.pipeline", "GLM_IMAGE_PIPELINE", ), + "Gr00tN1d7": ( + "vllm_omni.model_executor.models.gr00t.pipeline", + "GR00T_N1D7_PIPELINE", + ), "hunyuan_image_3_moe": ( "vllm_omni.model_executor.models.hunyuan_image3.pipeline", "HUNYUAN_IMAGE3_PIPELINE", diff --git a/vllm_omni/deploy/Gr00tN1d7.yaml b/vllm_omni/deploy/Gr00tN1d7.yaml new file mode 100644 index 00000000000..8ebb28d5163 --- /dev/null +++ b/vllm_omni/deploy/Gr00tN1d7.yaml @@ -0,0 +1,40 @@ +# GR00T N1.7 deploy: single diffusion stage. +# +# Topology is declared in vllm_omni/model_executor/models/gr00t/pipeline.py. +# This default uses one GPU with TP=1 and CFG parallel disabled. + +pipeline: Gr00tN1d7 +async_chunk: false +distributed_executor_backend: mp +dtype: bfloat16 + +stages: + - stage_id: 0 + devices: "0" + max_num_seqs: 1 + enforce_eager: true + model_class_name: Gr00tN1d7Pipeline + model_config: + embodiment_tag: OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT + strict: true + policy_server_config: + image_resolution: [180, 320] + needs_session_id: true + embodiment_tag: OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT + action_horizon: 40 + action_keys: + - eef_9d + - gripper_position + - joint_position + supported_embodiments: + - OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT + - XDOF + - XDOF_SUBTASK + - REAL_G1 + - REAL_R1_PRO_SHARPA + - REAL_R1_PRO_SHARPA_HUMAN + - REAL_R1_PRO_SHARPA_MAXINSIGHTS + - REAL_R1_PRO_SHARPA_MECKA + - LIBERO_PANDA + - SIMPLER_ENV_GOOGLE + - SIMPLER_ENV_WIDOWX diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 5fe2d4061cc..01026f35ebd 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -935,6 +935,10 @@ def enrich_config(self) -> None: self.model_class_name = "WanS2VPipeline" self.tf_model_config = TransformerConfig() self.update_multimodal_support() + elif model_type == "Gr00tN1d7" or "Gr00tN1d7" in architectures: + self.model_class_name = "Gr00tN1d7Pipeline" + self.set_tf_model_config(TransformerConfig()) + self.update_multimodal_support() elif architectures and len(architectures) == 1: self.model_class_name = architectures[0] else: @@ -1006,10 +1010,10 @@ class DiffusionOutput: """ # Fields may be replaced with SHM handle dicts by ipc.pack_diffusion_output_shm - output: torch.Tensor | dict | None = None - trajectory_timesteps: torch.Tensor | dict | None = None - trajectory_latents: torch.Tensor | dict | None = None - trajectory_log_probs: torch.Tensor | dict | None = None + output: torch.Tensor | tuple[Any, ...] | dict[str, Any] | None = None + trajectory_timesteps: torch.Tensor | dict[str, Any] | None = None + trajectory_latents: torch.Tensor | dict[str, Any] | None = None + trajectory_log_probs: torch.Tensor | dict[str, Any] | None = None trajectory_decoded: list[Image.Image] | None = None error: str | None = None aborted: bool = False @@ -1017,6 +1021,11 @@ class DiffusionOutput: post_process_func: Callable[..., Any] | None = None + # Multimodal payloads that should be surfaced on OmniRequestOutput. + # Robot policies use this for `{"actions": ...}` without pretending the + # action dict is an image/video output. + multimodal_output: dict[str, Any] = field(default_factory=dict) + # Extra custom output data (e.g. latent trajectories, prompt embeds) # passed through to OmniRequestOutput.custom_output custom_output: dict[str, Any] = field(default_factory=dict) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 1805b5dc349..d255f8b22d9 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import asyncio import concurrent.futures import inspect @@ -220,7 +218,8 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: raise RuntimeError(output.error) logger.debug("Generation completed successfully.") - if output.output is None: + base_multimodal_output = output.multimodal_output + if output.output is None and not base_multimodal_output: logger.warning("Output is None, returning empty OmniRequestOutput") return [ OmniRequestOutput.from_diffusion( @@ -258,12 +257,18 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: custom_output = output.custom_output or {} model_audio_sample_rate = None model_fps = None + non_visual_output_keys = {"audio", "actions", "custom_output", "audio_sample_rate", "fps"} if isinstance(outputs, dict): audio_payload = outputs.get("audio") + if "actions" in outputs: + base_multimodal_output["actions"] = outputs["actions"] custom_output.update(outputs.get("custom_output") or {}) model_audio_sample_rate = outputs.get("audio_sample_rate") model_fps = outputs.get("fps") - outputs = outputs.get("video", outputs) + if "video" in outputs: + outputs = outputs["video"] + elif outputs.keys() <= non_visual_output_keys: + outputs = None postprocess_time = time.perf_counter() - postprocess_start_time logger.debug("Post-processing completed in %.4f seconds", postprocess_time) @@ -346,13 +351,14 @@ def _audio_mm(payload: Any) -> dict[str, Any]: ), ] else: - mm_output = {} + mm_output = base_multimodal_output.copy() if audio_payload is not None: mm_output["audio"] = audio_payload if model_audio_sample_rate is not None: mm_output["audio_sample_rate"] = model_audio_sample_rate if model_fps is not None: mm_output["fps"] = model_fps + final_output_type = "actions" if "actions" in mm_output and not outputs else "image" return [ OmniRequestOutput.from_diffusion( request_id=request_id, @@ -366,6 +372,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]: trajectory_decoded=output.trajectory_decoded, custom_output=custom_output, multimodal_output=mm_output, + final_output_type=final_output_type, stage_durations=output.stage_durations, peak_memory_mb=output.peak_memory_mb, ), @@ -405,7 +412,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]: ), ) else: - mm_output = {} + mm_output = base_multimodal_output.copy() if audio_payload is not None: sliced_audio = audio_payload if isinstance(audio_payload, (list, tuple)): @@ -422,6 +429,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]: mm_output["audio_sample_rate"] = model_audio_sample_rate if model_fps is not None: mm_output["fps"] = model_fps + final_output_type = "actions" if "actions" in mm_output and not request_outputs else "image" results.append( OmniRequestOutput.from_diffusion( request_id=request_id, @@ -435,6 +443,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]: trajectory_decoded=output.trajectory_decoded, custom_output=custom_output, multimodal_output=mm_output, + final_output_type=final_output_type, stage_durations=output.stage_durations, peak_memory_mb=output.peak_memory_mb, ), @@ -569,7 +578,7 @@ def _handle_finished_requests( def make_engine( config: OmniDiffusionConfig, scheduler: SchedulerInterface | None = None, - ) -> DiffusionEngine: + ) -> "DiffusionEngine": """Factory method to create a DiffusionEngine instance. Args: diff --git a/vllm_omni/diffusion/models/gr00t/__init__.py b/vllm_omni/diffusion/models/gr00t/__init__.py new file mode 100644 index 00000000000..a10202a40bc --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.gr00t.pipeline_gr00t import Gr00tN1d7Pipeline + +__all__ = ["Gr00tN1d7Pipeline"] diff --git a/vllm_omni/diffusion/models/gr00t/configs/__init__.py b/vllm_omni/diffusion/models/gr00t/configs/__init__.py new file mode 100644 index 00000000000..9648b859751 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/configs/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +MODEL_CONFIG_TYPES: dict[str, type] = {} + + +def register_model_config(shortname: str, configtype: type) -> None: + MODEL_CONFIG_TYPES[shortname] = configtype diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py new file mode 100644 index 00000000000..7d11a651cf6 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.gr00t.dataio.types import ( + ActionConfig, + ActionFormat, + ActionRepresentation, + ActionType, + ModalityConfig, +) + +__all__ = ["ActionConfig", "ActionFormat", "ActionRepresentation", "ActionType", "ModalityConfig"] diff --git a/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py new file mode 100644 index 00000000000..d7d4527fb93 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers import PretrainedConfig + +from . import register_model_config + + +def _default_diffusion_model_cfg() -> dict: + return { + "positional_embeddings": None, + "num_layers": 16, + "num_attention_heads": 32, + "attention_head_dim": 48, + "norm_type": "ada_norm", + "output_dim": 1024, + "interleave_self_attention": True, + } + + +class Gr00tN1d7Config(PretrainedConfig): + model_type = "Gr00tN1d7" + + def __init__( + self, + model_dtype: str = "bfloat16", + model_name: str = "nvidia/Cosmos-Reason2-2B", + backbone_model_type: str = "qwen", + model_revision: str | None = None, + backbone_embedding_dim: int = 2048, + select_layer: int = 12, + reproject_vision: bool = False, + use_flash_attention: bool = True, + load_bf16: bool = False, + image_crop_size: tuple | None = (230, 230), + image_target_size: tuple | None = (256, 256), + shortest_image_edge: int | None = None, + crop_fraction: float | None = None, + random_rotation_angle: int | None = None, + color_jitter_params: dict | None = None, + formalize_language: bool = True, + apply_sincos_state_encoding: bool = False, + use_percentiles: bool = True, + use_relative_action: bool = False, + max_state_dim: int = 132, + max_action_dim: int = 132, + action_horizon: int = 40, + hidden_size: int = 1024, + input_embedding_dim: int = 1536, + state_history_length: int = 1, + add_pos_embed: bool = True, + attn_dropout: float = 0.2, + use_vlln: bool = True, + max_seq_len: int = 1024, + use_alternate_vl_dit: bool = True, + attend_text_every_n_blocks: int = 2, + diffusion_model_cfg: dict | None = None, + vl_self_attention_cfg: dict | None = None, + num_inference_timesteps: int = 4, + num_timestep_buckets: int = 1000, + exclude_state: bool = False, + use_mean_std: bool = False, + max_num_embodiments: int = 32, + **kwargs, + ): + super().__init__(**kwargs) + self.model_dtype = model_dtype + self.model_name = model_name + self.backbone_model_type = backbone_model_type + self.model_revision = model_revision + self.backbone_embedding_dim = backbone_embedding_dim + self.select_layer = select_layer + self.reproject_vision = reproject_vision + self.use_flash_attention = use_flash_attention + self.load_bf16 = load_bf16 + self.image_crop_size = image_crop_size + self.image_target_size = image_target_size + self.shortest_image_edge = shortest_image_edge + self.crop_fraction = crop_fraction + self.random_rotation_angle = random_rotation_angle + self.color_jitter_params = color_jitter_params + self.formalize_language = formalize_language + self.apply_sincos_state_encoding = apply_sincos_state_encoding + self.use_percentiles = use_percentiles + self.use_relative_action = use_relative_action + self.max_state_dim = max_state_dim + self.max_action_dim = max_action_dim + self.action_horizon = action_horizon + self.hidden_size = hidden_size + self.input_embedding_dim = input_embedding_dim + self.state_history_length = state_history_length + self.add_pos_embed = add_pos_embed + self.attn_dropout = attn_dropout + self.use_vlln = use_vlln + self.max_seq_len = max_seq_len + self.use_alternate_vl_dit = use_alternate_vl_dit + self.attend_text_every_n_blocks = attend_text_every_n_blocks + self.diffusion_model_cfg = ( + diffusion_model_cfg if diffusion_model_cfg is not None else _default_diffusion_model_cfg() + ) + self.vl_self_attention_cfg = vl_self_attention_cfg + self.num_inference_timesteps = num_inference_timesteps + self.num_timestep_buckets = num_timestep_buckets + self.exclude_state = exclude_state + self.use_mean_std = use_mean_std + self.max_num_embodiments = max_num_embodiments + + +register_model_config("Gr00tN1d7", Gr00tN1d7Config) diff --git a/vllm_omni/diffusion/models/gr00t/dataio/__init__.py b/vllm_omni/diffusion/models/gr00t/dataio/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py b/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py new file mode 100755 index 00000000000..3ed0c9b0360 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum + +""" +Embodiment tags are used to identify the robot embodiment in the data. + +Naming convention: +_ + +If using multiple datasets, e.g. sim GR1 and real GR1, we can drop the dataset name and use only the robot name. +""" + + +class EmbodimentTag(Enum): + """Embodiment tags supported by the GR00T N1.7 checkpoint. + + Pretrain tags (baked into the base model nvidia/GR00T-N1.7-3B, inference-ready): + - OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT -> "oxe_droid_relative_eef_relative_joint" + - XDOF -> "xdof_relative_eef_relative_joint" + - XDOF_SUBTASK -> "xdof_relative_eef_relative_joint_subtask" + - REAL_G1 -> "real_g1_relative_eef_relative_joints" + - REAL_R1_PRO_SHARPA -> "real_r1_pro_sharpa_relative_eef" + - REAL_R1_PRO_SHARPA_HUMAN -> "real_r1_pro_sharpa_relative_eef_human" + - REAL_R1_PRO_SHARPA_MAXINSIGHTS -> "real_r1_pro_sharpa_relative_eef_maxinsights" + - REAL_R1_PRO_SHARPA_MECKA -> "real_r1_pro_sharpa_relative_eef_mecka" + + Pre-registered posttrain tags (require finetuned checkpoint): + - UNITREE_G1 -> "unitree_g1_full_body_with_waist_height_nav_cmd" + - UNITREE_G1_SONIC -> "unitree_g1_sonic" + - SIMPLER_ENV_GOOGLE -> "simpler_env_google" + - SIMPLER_ENV_WIDOWX -> "simpler_env_widowx" + - LIBERO_PANDA -> "libero_sim" + + Finetuning tag (for custom robots): + - NEW_EMBODIMENT -> "new_embodiment" + + Use ``EmbodimentTag.resolve(s)`` to look up a tag by name or value, + case-insensitively. + """ + + OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT = "oxe_droid_relative_eef_relative_joint" + """ + The Open-X-Embodiment DROID robot with relative EEF and relative joint position actions. + """ + + XDOF = "xdof_relative_eef_relative_joint" + """ + The generic X-DOF robot with relative EEF and relative joint position actions. + """ + + XDOF_SUBTASK = "xdof_relative_eef_relative_joint_subtask" + """ + The generic X-DOF robot (subtask variant). + """ + + REAL_G1 = "real_g1_relative_eef_relative_joints" + """ + Real-world Unitree G1 with relative EEF and relative joint actions. + """ + + REAL_R1_PRO_SHARPA = "real_r1_pro_sharpa_relative_eef" + """ + Real-world R1 Pro Sharpa with relative EEF actions. + """ + + REAL_R1_PRO_SHARPA_HUMAN = "real_r1_pro_sharpa_relative_eef_human" + """ + Real-world R1 Pro Sharpa with relative EEF actions (human teleop data). + """ + + REAL_R1_PRO_SHARPA_MAXINSIGHTS = "real_r1_pro_sharpa_relative_eef_maxinsights" + """ + Real-world R1 Pro Sharpa with relative EEF actions (MaxInsights data, single-cam). + """ + + REAL_R1_PRO_SHARPA_MECKA = "real_r1_pro_sharpa_relative_eef_mecka" + """ + Real-world R1 Pro Sharpa with relative EEF actions (Mecka data, single-cam). + """ + + UNITREE_G1 = "unitree_g1_full_body_with_waist_height_nav_cmd" + """ + The Unitree G1 robot (sim, full-body with waist height and nav commands). + """ + + UNITREE_G1_SONIC = "unitree_g1_sonic" + """ + The Unitree G1 robot with SONIC whole-body controller. VLA action space is SONIC latents. + """ + + SIMPLER_ENV_GOOGLE = "simpler_env_google" + """ + The SimplerEnv Google robot. + """ + + SIMPLER_ENV_WIDOWX = "simpler_env_widowx" + """ + The SimplerEnv WidowX robot. + """ + + LIBERO_PANDA = "libero_sim" + """ + The LIBERO Panda robot (used for LIBERO-Goal, LIBERO-Object, LIBERO-Spatial, LIBERO-10). + """ + + # New embodiment during post-training + NEW_EMBODIMENT = "new_embodiment" + """ + Any new embodiment. + """ + + @classmethod + def resolve(cls, tag: "str | EmbodimentTag") -> "EmbodimentTag": + """Resolve a string to an EmbodimentTag, case-insensitively. + + Matches by enum **name** first (e.g. ``"xdof"`` -> ``XDOF``), then by + enum **value** (e.g. ``"xdof_relative_eef_relative_joint"`` -> ``XDOF``). + + Raises: + ValueError: If *tag* does not match any known embodiment. + """ + if isinstance(tag, cls): + return tag + key = tag.strip() + key_lower = key.lower() + # Match by enum name (case-insensitive) + for member in cls: + if member.name.lower() == key_lower: + return member + # Match by enum value (case-insensitive) + for member in cls: + if member.value.lower() == key_lower: + return member + + def _fmt(tags): + return "\n".join(f" {m.name:40s} -> {m.value}" for m in tags) + + msg = ( + f"Unknown embodiment tag: {tag!r}\n\n" + f" Base model tags (work with nvidia/GR00T-N1.7-3B):\n" + f"{_fmt(PRETRAIN_TAGS)}\n\n" + f" Posttrain tags (require a finetuned checkpoint):\n" + f"{_fmt(POSTTRAIN_TAGS)}\n\n" + f" Finetuning-only tags (for custom robots):\n" + f"{_fmt(FINETUNE_ONLY_TAGS)}" + ) + raise ValueError(msg) + + @classmethod + def reverse_lookup(cls, value: str) -> "str": + """Map a tag value string back to its enum name, or return the value as-is.""" + for member in cls: + if member.value == value: + return member.name + return value + + +# Module-level tag category sets (cannot be Enum class attributes). +PRETRAIN_TAGS: frozenset[EmbodimentTag] = frozenset( + { + EmbodimentTag.OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT, + EmbodimentTag.XDOF, + EmbodimentTag.XDOF_SUBTASK, + EmbodimentTag.REAL_G1, + EmbodimentTag.REAL_R1_PRO_SHARPA, + EmbodimentTag.REAL_R1_PRO_SHARPA_HUMAN, + EmbodimentTag.REAL_R1_PRO_SHARPA_MAXINSIGHTS, + EmbodimentTag.REAL_R1_PRO_SHARPA_MECKA, + } +) +"""Tags baked into the base model (nvidia/GR00T-N1.7-3B) — usable without finetuning.""" + +POSTTRAIN_TAGS: frozenset[EmbodimentTag] = frozenset( + { + EmbodimentTag.UNITREE_G1, + EmbodimentTag.UNITREE_G1_SONIC, + EmbodimentTag.SIMPLER_ENV_GOOGLE, + EmbodimentTag.SIMPLER_ENV_WIDOWX, + EmbodimentTag.LIBERO_PANDA, + } +) +"""Tags that require a finetuned checkpoint.""" + +FINETUNE_ONLY_TAGS: frozenset[EmbodimentTag] = frozenset( + { + EmbodimentTag.NEW_EMBODIMENT, + } +) +"""Tags for custom robots (finetuning only, not in any shipped checkpoint).""" diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/__init__.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py new file mode 100644 index 00000000000..4128d6cae6c --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py @@ -0,0 +1,654 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Generic, TypeVar + +import numpy as np +from numpy.typing import NDArray +from scipy import interpolate +from scipy.spatial.transform import Rotation, Slerp + +from vllm_omni.diffusion.models.gr00t.dataio.state_action.pose import EndEffectorPose, JointPose, Pose +from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat + +PoseType = TypeVar("PoseType", bound=Pose) + + +class ActionChunk(Generic[PoseType]): + """ + Abstract base class for robot action chunking. + + This class provides common functionality for different action chunking types + including relative and delta action chunking computation with optional reference frames, + interpolation, and format conversion. + """ + + def __init__( + self, + poses: Sequence[PoseType], + times: Sequence[float] | NDArray[np.float64] | None = None, + ): + """ + Initialize action chunking from a list of poses. + + Args: + poses: Sequence of Pose objects + times: Optional sequence of timestamps for each pose. If None, assumes + uniform spacing starting from 0 with step 1.0 + + Raises: + ValueError: If action chunking is empty or times length doesn't match poses + """ + if not poses: + raise ValueError("ActionChunk must contain at least one pose") + + self._poses: list[PoseType] = list(poses) + + # Set up times + if times is None: + self._times = np.arange(len(poses), dtype=np.float64) + else: + if len(times) != len(poses): + raise ValueError("Number of times must match number of poses") + self._times = np.array(times, dtype=np.float64) + + @property + def poses(self) -> list[PoseType]: + """Get the list of poses""" + return self._poses.copy() + + @property + def times(self) -> NDArray[np.float64]: + """Get the timestamps""" + return self._times.copy() + + @property + def num_poses(self) -> int: + """Get the number of poses in the action chunking""" + return len(self._poses) + + def relative_chunking(self, reference_frame: PoseType | None = None) -> "ActionChunk[PoseType]": + """ + Compute the relative action chunking with respect to a reference frame. + + If reference_frame is None, uses the first pose in the action chunking as reference. + All poses are transformed to be relative to the reference frame. + + Args: + reference_frame: Optional reference pose. If None, uses first pose. + + Returns: + A new ActionChunk of the same type where all poses are relative to the reference frame. + """ + if not self._poses: + return self.__class__([], times=[]) + + # Use the first pose as the reference if one is not provided. + ref_pose = reference_frame if reference_frame is not None else self._poses[0] + + # Use the polymorphic subtraction defined in the Pose subclasses. + # The subtraction returns the same type as the operands + relative_poses: list[PoseType] = [pose - ref_pose for pose in self._poses] # type: ignore[misc] + + # Return a new instance of the same action chunking class + # (e.g., JointActionChunk or EndEffectorActionChunk) + return self.__class__(relative_poses, times=self.times) + + def delta_chunking(self, reference_frame: PoseType | None = None) -> "ActionChunk[PoseType]": + """ + Compute the delta action chunking where each pose represents the relative + transformation from the previous frame. + + If reference_frame is provided, it is treated as the first frame, and the + first delta will be from reference_frame to the first pose in the action chunking. + Otherwise, the first pose in the delta action chunking will be the identity/zero transformation. + + Args: + reference_frame: Optional reference pose to use as the first frame. + + Returns: + A new ActionChunk of the same type where each pose is relative to the previous pose. + """ + if not self._poses: + return self.__class__([], times=[]) + + delta_poses: list[PoseType] = [] + + # Determine the initial reference for the very first pose. + # If a reference_frame is given, the first delta is pose[0] - reference_frame. + # If not, the first delta is pose[0] - pose[0], resulting in an identity/zero pose. + prev_pose = reference_frame if reference_frame is not None else self._poses[0] + + for current_pose in self._poses: + delta: PoseType = current_pose - prev_pose # type: ignore[assignment] + delta_poses.append(delta) + prev_pose = current_pose # Update the reference for the next step + + return self.__class__(delta_poses, times=self.times.tolist()) + + def to_absolute_chunking(self, reference_frame: PoseType) -> "ActionChunk[PoseType]": + """ + Convert a relative action chunking to an absolute action chunking by applying + the relative poses on top of a reference frame. + + This is the inverse operation of relative_chunking(). Each relative pose + is composed with the reference frame to produce absolute poses. + + Args: + reference_frame: The reference pose to apply the relative action chunking on top of. + + Returns: + A new ActionChunk of the same type with absolute poses. + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError("Subclasses must implement to_absolute_chunking") + + def interpolate( + self, + num_points: int | None = None, + times: NDArray[np.float64] | None = None, + ) -> "ActionChunk": + """ + Interpolate the action chunking to generate intermediate poses. + Must be implemented by subclasses. + + Args: + num_points: Number of evenly-spaced points to generate + times: Specific timestamps at which to interpolate + + Returns: + A new ActionChunk with interpolated poses + """ + raise NotImplementedError("Subclasses must implement interpolate") + + def to(self, action_format: ActionFormat) -> NDArray[np.float64]: + """ + Convert action chunking to the specified action format. + Must be implemented by subclasses. + + Args: + action_format: The desired output format + + Returns: + Array in the requested format + + Raises: + NotImplementedError: If not implemented by subclass + """ + raise NotImplementedError("Subclasses must implement to method") + + def __len__(self) -> int: + """Return the number of poses in the action chunking""" + return len(self._poses) + + def __getitem__(self, index: int) -> PoseType: + """Get a pose by index""" + return self._poses[index] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(num_poses={len(self._poses)}, " + f"time_range=[{self._times[0]:.2f}, {self._times[-1]:.2f}])" + ) + + +class JointActionChunk(ActionChunk[JointPose]): + """ + Represents action chunking in joint space as a sequence of joint configurations. + + Examples: + # Create a joint action chunking + joint_poses = [ + JointPose([0.0, 0.0, 0.0]), + JointPose([0.5, 0.5, 0.5]), + JointPose([1.0, 1.0, 1.0]), + ] + action_chunking = JointActionChunk(joint_poses) + + # Get relative trajectory (all poses relative to first pose) + relative_traj = action_chunking.relative_chunking() + + # Get relative trajectory with custom reference + reference = JointPose([0.1, 0.1, 0.1]) + relative_traj = action_chunking.relative_chunking(reference_frame=reference) + + # Get delta trajectory (incremental changes) + delta_traj = action_chunking.delta_chunking() + + # Convert relative trajectory back to absolute + reference = JointPose([0.1, 0.1, 0.1]) + absolute_traj = relative_traj.to_absolute_chunking(reference_frame=reference) + + # Interpolate trajectory + interpolated = action_chunking.interpolate(num_points=10) + + # Convert to desired format + from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat + array_data = action_chunking.to(ActionFormat.DEFAULT) # Returns joint array + """ + + def __init__( + self, + poses: Sequence[JointPose], + times: Sequence[float] | NDArray[np.float64] | None = None, + ): + """ + Initialize a joint trajectory from a list of joint poses. + + Args: + poses: Sequence of JointPose objects + times: Optional sequence of timestamps for each pose + + Raises: + TypeError: If poses are not all JointPose objects + """ + # Validate all poses are JointPose + if not all(isinstance(p, JointPose) for p in poses): + raise TypeError("All poses must be JointPose objects for JointActionChunk") + + super().__init__(poses, times) + + def interpolate( + self, + num_points: int | None = None, + times: NDArray[np.float64] | None = None, + ) -> "JointActionChunk": + """ + Interpolate the joint action chunking to generate intermediate configurations. + + Uses linear interpolation for joint values. + + Args: + num_points: Number of evenly-spaced points to generate (including endpoints). + Only used if times is None. + times: Specific timestamps at which to interpolate. If provided, + num_points is ignored. + + Returns: + A new JointActionChunk with interpolated poses + + Raises: + ValueError: If neither num_points nor times is provided, or if + interpolation times are outside the trajectory range + """ + if num_points is None and times is None: + raise ValueError("Must provide either num_points or times") + + if len(self._poses) < 2: + raise ValueError("Need at least 2 poses for interpolation") + + # Prepare data: extract joint values + timestamps = self._times.copy() + joint_values = np.array([pose.joints for pose in self._poses]) # (N, num_joints) + + # Find and remove non-monotonic timestamps + drop_indices = [idx for idx in range(1, len(timestamps)) if timestamps[idx] <= timestamps[idx - 1]] + + if drop_indices: + for idx in drop_indices: + print( + f"Dropping timestamp pair - Previous: {timestamps[idx - 1]}, " + f"Current: {timestamps[idx]} at index {idx}" + ) + timestamps = np.delete(timestamps, drop_indices) + joint_values = np.delete(joint_values, drop_indices, axis=0) + + # Check if we still have enough poses after cleanup + if len(timestamps) < 2: + raise ValueError("Need at least 2 poses with monotonic timestamps for interpolation") + + # Create interpolator + joint_interp = interpolate.interp1d(timestamps, joint_values, kind="linear", axis=0) + + # Generate interpolation times if not provided + if times is None: + assert num_points is not None # Type narrowing for type checker + interp_times = np.linspace(timestamps[0], timestamps[-1], num_points) + else: + interp_times = np.array(times, dtype=np.float64) + + # Check that interpolation times are within bounds + if np.any(interp_times < timestamps[0]) or np.any(interp_times > timestamps[-1]): + raise ValueError(f"Interpolation times must be within [{timestamps[0]}, {timestamps[-1]}]") + + # Interpolate joint values + interp_joint_values = joint_interp(interp_times) + + # Create interpolated poses + joint_names = self._poses[0].joint_names + interpolated_poses = [ + JointPose(joints=interp_joint_values[i], joint_names=joint_names) for i in range(len(interp_times)) + ] + + return JointActionChunk(interpolated_poses, times=interp_times) + + def to_array(self) -> NDArray[np.float64]: + """ + Convert trajectory to array of joint values. + + Returns: + Array with shape (N, num_joints) where N is the number of poses + """ + return np.array([pose.joints for pose in self._poses]) + + def to_absolute_chunking(self, reference_frame: JointPose) -> "JointActionChunk": + """ + Convert a relative joint action chunking to an absolute action chunking by adding + the relative joint positions to the reference frame. + + This is the inverse operation of relative_chunking(). Each relative joint + configuration is added to the reference frame to produce absolute joint positions. + + Args: + reference_frame: The reference joint pose to apply the relative trajectory on top of. + + Returns: + A new JointActionChunk with absolute joint positions. + + Raises: + ValueError: If joint dimensions don't match + """ + if not self._poses: + return JointActionChunk([], times=[]) + + if len(self._poses[0].joints) != len(reference_frame.joints): + raise ValueError( + f"Cannot apply relative trajectory: " + f"joint dimensions don't match ({len(self._poses[0].joints)} vs {len(reference_frame.joints)})" + ) + + # Add each relative pose to the reference frame + absolute_poses: list[JointPose] = [] + for relative_pose in self._poses: + absolute_joints = reference_frame.joints + relative_pose.joints + absolute_pose = JointPose(joints=absolute_joints, joint_names=reference_frame.joint_names) + absolute_poses.append(absolute_pose) + + return JointActionChunk(absolute_poses, times=self.times) + + def to(self, action_format: ActionFormat) -> NDArray[np.float64]: + """ + Convert trajectory to the desired format. + + Args: + action_format: The desired output format + + Returns: + Array in the requested format + + Raises: + ValueError: If the action format is not supported for joint trajectories + """ + if action_format == ActionFormat.DEFAULT: + return self.to_array() + else: + raise ValueError( + f"ActionFormat {action_format} is not supported for JointActionChunk. " + f"Only {ActionFormat.DEFAULT} is supported." + ) + + +class EndEffectorActionChunk(ActionChunk[EndEffectorPose]): + """ + Represents action chunking in Cartesian space as a sequence of end-effector poses. + + Examples: + # Create an end-effector action chunking + ee_poses = [ + EndEffectorPose(translation=[0, 0, 0], rotation=[1, 0, 0, 0], + rotation_type="quat", rotation_order="wxyz"), + EndEffectorPose(translation=[1, 0, 0], rotation=[0.707, 0, 0, 0.707], + rotation_type="quat", rotation_order="wxyz"), + EndEffectorPose(translation=[2, 0, 0], rotation=[0, 0, 0, 1], + rotation_type="quat", rotation_order="wxyz"), + ] + action_chunking = EndEffectorActionChunk(ee_poses) + + # Get relative trajectory (all poses relative to first pose) + relative_traj = action_chunking.relative_chunking() + + # Get relative trajectory with custom reference frame + reference = EndEffectorPose(translation=[0.5, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + relative_traj = action_chunking.relative_chunking(reference_frame=reference) + + # Get delta trajectory + delta_traj = action_chunking.delta_chunking() + + # Convert relative trajectory back to absolute + reference = EndEffectorPose(translation=[0.5, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + absolute_traj = relative_traj.to_absolute_chunking(reference_frame=reference) + + # Interpolate trajectory + interpolated = action_chunking.interpolate(num_points=10) + + # Convert to desired format + from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat + homo_matrices = action_chunking.to(ActionFormat.DEFAULT) # (N, 4, 4) homogeneous matrices + xyz_rot6d = action_chunking.to(ActionFormat.XYZ_ROT6D) # (N, 9) xyz + rot6d + xyz_rotvec = action_chunking.to(ActionFormat.XYZ_ROTVEC) # (N, 6) xyz + rotvec + """ + + def __init__( + self, + poses: Sequence[EndEffectorPose], + times: Sequence[float] | NDArray[np.float64] | None = None, + ): + """ + Initialize an end-effector trajectory from a list of end-effector poses. + + Args: + poses: Sequence of EndEffectorPose objects + times: Optional sequence of timestamps for each pose + + Raises: + TypeError: If poses are not all EndEffectorPose objects + """ + # Validate all poses are EndEffectorPose + if not all(isinstance(p, EndEffectorPose) for p in poses): + raise TypeError("All poses must be EndEffectorPose objects for EndEffectorActionChunk") + + super().__init__(poses, times) + + @classmethod + def from_array(cls, data: np.ndarray, action_format: ActionFormat) -> "EndEffectorActionChunk": + """ + Create an EndEffectorActionChunk from a 2-D array using the specified action format. + + This is the inverse of ``.to(action_format)``. + + Args: + data: Array of shape (N, D) where D depends on the action_format. + action_format: The format that describes the layout of each row. + + Returns: + EndEffectorActionChunk with N poses. + """ + poses = [EndEffectorPose.from_action_format(row, action_format) for row in data] + return cls(poses) + + def interpolate( + self, + num_points: int | None = None, + times: NDArray[np.float64] | None = None, + ) -> "EndEffectorActionChunk": + """ + Interpolate the action chunking to generate intermediate poses. + + Uses linear interpolation for translation and SLERP (Spherical Linear + Interpolation) for rotation. + + Args: + num_points: Number of evenly-spaced points to generate (including endpoints). + Only used if times is None. + times: Specific timestamps at which to interpolate. If provided, + num_points is ignored. + + Returns: + A new EndEffectorActionChunk with interpolated poses + + Raises: + ValueError: If neither num_points nor times is provided, or if + interpolation times are outside the trajectory range + """ + if num_points is None and times is None: + raise ValueError("Must provide either num_points or times") + + if len(self._poses) < 2: + raise ValueError("Need at least 2 poses for interpolation") + + # Prepare data: extract positions and rotations + timestamps = self._times.copy() + homogeneous_matrices = np.array([pose.homogeneous for pose in self._poses]) + positions = homogeneous_matrices[:, :3, 3] + rotations = Rotation.from_matrix(homogeneous_matrices[:, :3, :3]) + + # Find indices where timestamps are not monotonically increasing + drop_indices = [idx for idx in range(1, len(timestamps)) if timestamps[idx] <= timestamps[idx - 1]] + + # Remove the problematic timestamps and corresponding data + if drop_indices: + for idx in drop_indices: + print( + f"Dropping timestamp pair - Previous: {timestamps[idx - 1]}, " + f"Current: {timestamps[idx]} at index {idx}" + ) + timestamps = np.delete(timestamps, drop_indices) + positions = np.delete(positions, drop_indices, axis=0) + rotations = Rotation.from_matrix(np.delete(homogeneous_matrices[:, :3, :3], drop_indices, axis=0)) + + # Check if we still have enough poses after cleanup + if len(timestamps) < 2: + raise ValueError("Need at least 2 poses with monotonic timestamps for interpolation") + + # Create interpolators + pos_interp = interpolate.interp1d(timestamps, positions, kind="linear", axis=0) + rot_interp = Slerp(timestamps, rotations) + + # Generate interpolation times if not provided + if times is None: + assert num_points is not None # Type narrowing for type checker + interp_times = np.linspace(timestamps[0], timestamps[-1], num_points) + else: + interp_times = np.array(times, dtype=np.float64) + + # Check that interpolation times are within bounds + if np.any(interp_times < timestamps[0]) or np.any(interp_times > timestamps[-1]): + raise ValueError(f"Interpolation times must be within [{timestamps[0]}, {timestamps[-1]}]") + + # Interpolate positions and rotations + interp_positions = pos_interp(interp_times) + interp_rotations = rot_interp(interp_times) + + # Create interpolated poses + interpolated_poses = [] + for i in range(len(interp_times)): + pose = EndEffectorPose( + translation=interp_positions[i], + rotation=interp_rotations[i].as_matrix(), + rotation_type="matrix", + ) + interpolated_poses.append(pose) + + return EndEffectorActionChunk(interpolated_poses, times=interp_times) + + def to_homogeneous_matrices(self) -> NDArray[np.float64]: + """ + Convert trajectory to array of homogeneous transformation matrices. + + Returns: + Array of homogeneous matrices with shape (N, 4, 4) where N is the number of poses + """ + return np.array([pose.homogeneous for pose in self._poses]) + + def to_translation_rot6d(self) -> NDArray[np.float64]: + """ + Convert trajectory to array of translations and 6D rotations. + + Returns: + Array with shape (N, 9) - 3 for xyz + 6 for rot6d + """ + translations = np.array([pose.translation for pose in self._poses]) # (N, 3) + rotations = np.array([pose.rot6d for pose in self._poses]) # (N, 6) + + # Concatenate translation and rotation + xyz_rot6d = np.concatenate([translations, rotations], axis=1) # (N, 9) + + return xyz_rot6d + + def to_translation_rotvec(self) -> NDArray[np.float64]: + """ + Convert trajectory to array of translations and rotation vectors. + + Returns: + Array with shape (N, 6) - 3 for xyz + 3 for rotvec + """ + translations = np.array([pose.translation for pose in self._poses]) # (N, 3) + rotations = np.array([pose.rotvec for pose in self._poses]) # (N, 3) + + # Concatenate translation and rotation + xyz_rotvec = np.concatenate([translations, rotations], axis=1) # (N, 6) + + return xyz_rotvec + + def to_absolute_chunking(self, reference_frame: EndEffectorPose) -> "EndEffectorActionChunk": + """ + Convert a relative end-effector action chunking to an absolute action chunking by + composing each relative transformation with the reference frame. + + This is the inverse operation of relative_chunking(). Each relative pose + represents a transformation that is applied on top of the reference frame + to produce absolute poses. + + Args: + reference_frame: The reference end-effector pose to apply the relative trajectory on top of. + + Returns: + A new EndEffectorActionChunk with absolute poses. + """ + if not self._poses: + return EndEffectorActionChunk([], times=[]) + + # Get reference frame as homogeneous matrix + T_ref = reference_frame.homogeneous + + # Compose each relative transformation with the reference frame + absolute_poses: list[EndEffectorPose] = [] + for relative_pose in self._poses: + # Get relative transformation as homogeneous matrix + T_relative = relative_pose.homogeneous + + # Compose transformations: T_absolute = T_ref @ T_relative + T_absolute = T_ref @ T_relative + + # Create absolute pose from composed transformation + absolute_pose = EndEffectorPose(homogeneous=T_absolute) + absolute_poses.append(absolute_pose) + + return EndEffectorActionChunk(absolute_poses, times=self.times) + + def to(self, action_format: ActionFormat) -> NDArray[np.float64]: + """ + Convert trajectory to the desired format. + + Args: + action_format: The desired output format + + Returns: + Array in the requested format + + Raises: + ValueError: If the action format is not supported + """ + if action_format == ActionFormat.DEFAULT: + return self.to_homogeneous_matrices() + elif action_format == ActionFormat.XYZ_ROT6D: + return self.to_translation_rot6d() + elif action_format == ActionFormat.XYZ_ROTVEC: + return self.to_translation_rotvec() + else: + raise ValueError(f"Unsupported action format: {action_format}") diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py new file mode 100644 index 00000000000..0e0849f17b4 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py @@ -0,0 +1,709 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum +from typing import TypeVar + +import numpy as np +from numpy.typing import NDArray +from scipy.spatial.transform import Rotation + +from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat + +# TypeVar for self-type preservation in Pose operations +PoseT = TypeVar("PoseT", bound="Pose") + + +def invert_transformation(transform: NDArray[np.float64]) -> NDArray[np.float64]: + """ + Invert a homogeneous transformation matrix. + + Args: + transform: A 4x4 homogeneous transformation matrix + + Returns: + The inverse of the transformation matrix (4x4) + """ + R = transform[:3, :3] # Extract the rotation matrix + t = transform[:3, 3] # Extract the translation vector + + # Inverse of the rotation matrix is its transpose (since it's orthogonal) + R_inv = R.T + + # Inverse of the translation is -R_inv * t + t_inv = -R_inv @ t + + # Construct the inverse transformation matrix + T_inv = np.eye(4) + T_inv[:3, :3] = R_inv + T_inv[:3, 3] = t_inv + + return T_inv + + +def relative_transformation( + base_transform: NDArray[np.float64], target_transform: NDArray[np.float64] +) -> NDArray[np.float64]: + """ + Compute the relative transformation between two poses. + + Args: + base_transform: Initial 4x4 homogeneous transformation matrix + target_transform: Current 4x4 homogeneous transformation matrix + + Returns: + The relative transformation matrix (4x4) from base_transform to target_transform + """ + # Relative transformation is base_transform^{-1} * target_transform + T_relative = invert_transformation(base_transform) @ target_transform + return T_relative + + +class RotationType(Enum): + """Supported rotation representation types""" + + QUAT = "quat" + EULER = "euler" + ROTVEC = "rotvec" + MATRIX = "matrix" + ROT6D = "rot6d" + + +class EulerOrder(Enum): + """Common Euler angle conventions""" + + XYZ = "xyz" + ZYX = "zyx" + XZY = "xzy" + YXZ = "yxz" + YZX = "yzx" + ZXY = "zxy" + + +class QuatOrder(Enum): + """Quaternion ordering conventions""" + + WXYZ = "wxyz" # scalar-first (w, x, y, z) + XYZW = "xyzw" # scalar-last (x, y, z, w) + + +class Pose: + """ + Abstract base class for robot poses. + + This class provides common functionality for different pose representations + including relative pose computation via the subtraction operator. + """ + + pose_type: str + + def __sub__(self: PoseT, other: PoseT) -> PoseT: + """ + Compute relative transformation between two poses. + + For EndEffectorPose: Computes the relative transformation from other to self. + Result represents the transformation needed to go from other's frame to self's frame. + + For JointPose: Computes the joint-space difference (self - other). + + Args: + other: The reference pose to compute relative transformation from + + Returns: + Relative pose (same type as self) + + Raises: + TypeError: If poses are not of the same type + + Examples: + # End-effector poses + pose1 = EndEffectorPose(translation=[1, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + pose2 = EndEffectorPose(translation=[2, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + relative = pose2 - pose1 # Transformation from pose1 to pose2 + + # Joint poses + joint1 = JointPose([0.0, 0.5, 1.0]) + joint2 = JointPose([0.1, 0.6, 1.2]) + joint_diff = joint2 - joint1 # Joint differences: [0.1, 0.1, 0.2] + """ + if type(self) is not type(other): + raise TypeError( + f"Cannot compute relative transformation between different pose types: " + f"{type(self).__name__} and {type(other).__name__}" + ) + + return self._compute_relative(other) + + def _compute_relative(self: PoseT, other: PoseT) -> PoseT: + """ + Internal method to compute relative transformation. + Must be implemented by subclasses. + + Args: + other: The reference pose + + Returns: + Relative pose + """ + raise NotImplementedError("Subclasses must implement _compute_relative") + + def copy(self: PoseT) -> PoseT: + """ + Create a deep copy of this pose. + Must be implemented by subclasses. + + Returns: + New Pose instance with copied data + """ + raise NotImplementedError("Subclasses must implement copy") + + +class JointPose(Pose): + """ + Represents a robot configuration in joint space. + + This class stores joint angles/positions for a robot manipulator. + Unlike end-effector poses, joint poses represent the configuration + of all joints in the kinematic chain. + + Examples: + # Create a 6-DOF joint configuration + joint_pose = JointPose( + joints=[0.0, -np.pi/4, np.pi/2, 0.0, np.pi/4, 0.0], + joint_names=["shoulder_pan", "shoulder_lift", "elbow", + "wrist_1", "wrist_2", "wrist_3"] + ) + + # Create with default joint names + joint_pose = JointPose(joints=[0.0, 0.5, 1.0]) + + # Get as dictionary + joint_dict = joint_pose.to_dict() # {"joint_0": 0.0, ...} + + # Access individual joints + first_joint = joint_pose.joints[0] + num_joints = joint_pose.num_joints + + # Compute relative joint displacement + joint1 = JointPose([0.0, 0.5, 1.0]) + joint2 = JointPose([0.1, 0.6, 1.2]) + relative = joint2 - joint1 # [0.1, 0.1, 0.2] + """ + + pose_type = "joint" + + def __init__( + self, + joints: list | np.ndarray, + joint_names: list | None = None, + ): + """ + Initialize a joint pose. + + Args: + joints: Joint angles/positions as array-like of shape (n,) + joint_names: Optional list of names for each joint. If None, + defaults to ["joint_0", "joint_1", ...] + """ + super().__init__() + self.joints = np.array(joints, dtype=np.float64) + + # Set defaults and validate joint_names + if joint_names is None: + self.joint_names = [f"joint_{i}" for i in range(len(self.joints))] + else: + if len(joint_names) != len(self.joints): + raise ValueError( + f"Number of joint names ({len(joint_names)}) must match number of joints ({len(self.joints)})" + ) + self.joint_names = joint_names + + @property + def num_joints(self) -> int: + """ + Get the number of joints. + + Returns: + Number of joints in the configuration + """ + return len(self.joints) + + def to_dict(self) -> dict: + """ + Convert joint configuration to dictionary. + + Returns: + Dictionary mapping joint names to joint values + """ + return dict(zip(self.joint_names, self.joints)) + + def _compute_relative(self, other): # type: ignore[override] + """ + Compute relative joint displacement. + + Args: + other: Reference joint pose + + Returns: + JointPose representing the joint-space difference (self - other) + + Raises: + ValueError: If joint dimensions don't match + """ + if len(self.joints) != len(other.joints): + raise ValueError( + f"Cannot compute relative joint pose: " + f"joint dimensions don't match ({len(self.joints)} vs {len(other.joints)})" + ) + + relative_joints = self.joints - other.joints + return JointPose(joints=relative_joints, joint_names=self.joint_names) + + def copy(self) -> "JointPose": + """ + Create a deep copy of this joint pose. + + Returns: + New JointPose instance with copied data + """ + return JointPose( + joints=self.joints.copy(), + joint_names=self.joint_names.copy(), + ) + + def __repr__(self) -> str: + if len(self.joints) <= 6: + joints_str = np.array2string(self.joints, precision=4, suppress_small=True) + else: + joints_str = f"[{self.joints[0]:.4f}, ..., {self.joints[-1]:.4f}] ({len(self.joints)} joints)" + + return f"JointPose(joints={joints_str})" + + def __eq__(self, other) -> bool: + if not isinstance(other, JointPose): + return False + return np.allclose(self.joints, other.joints) and self.joint_names == other.joint_names + + def __getitem__(self, index) -> float | NDArray[np.float64]: + """Allow indexing: joint_pose[0] returns first joint value""" + return self.joints[index] + + def __len__(self) -> int: + """Allow len(): len(joint_pose) returns number of joints""" + return len(self.joints) + + +class EndEffectorPose(Pose): + """ + Represents a single end-effector pose with translation and rotation components. + + This class handles Cartesian space representations of robot end-effector poses, + supporting multiple rotation representations (quaternions, Euler angles, rotation + vectors, rotation matrices, etc.). + + Examples: + # Create with quaternion (wxyz order) + pose = EndEffectorPose( + translation=[1.0, 2.0, 3.0], + rotation=[1.0, 0.0, 0.0, 0.0], + rotation_type="quat", + rotation_order="wxyz" + ) + + # Create with Euler angles (degrees by default) + pose = EndEffectorPose( + translation=[1, 2, 3], + rotation=[0, 0, 90], + rotation_type="euler", + rotation_order="xyz" + ) + + # Create with Euler angles in radians + pose = EndEffectorPose( + translation=[1, 2, 3], + rotation=[0, 0, np.pi/2], + rotation_type="euler", + rotation_order="xyz", + degrees=False + ) + + # Create from homogeneous matrix + H = np.eye(4) + H[:3, 3] = [1, 2, 3] + pose = EndEffectorPose(homogeneous=H) + + # Convert between representations + quat_wxyz = pose.to_rotation("quat", "wxyz") + euler_zyx = pose.to_rotation("euler", "zyx") + rot6d = pose.to_rotation("rot6d") + + # Compute relative transformation + pose1 = EndEffectorPose(translation=[1, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + pose2 = EndEffectorPose(translation=[2, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + relative = pose2 - pose1 # Transformation from pose1's frame to pose2's frame + """ + + pose_type = "end_effector" + + def __init__( + self, + translation: list | np.ndarray | None = None, + rotation: list | np.ndarray | None = None, + rotation_type: str | None = None, + rotation_order: str | None = None, + homogeneous: np.ndarray | None = None, + degrees: bool = True, + ): + """ + Initialize an end-effector pose. + + Args: + translation: Translation vector [x, y, z] + rotation: Rotation in specified format + rotation_type: Type of rotation ("quat", "euler", "rotvec", "matrix", "rot6d") + rotation_order: Order/convention for the rotation type + homogeneous: Homogeneous transformation matrix (4, 4) + If provided, overrides translation and rotation + degrees: For Euler angles, whether the input is in degrees (default True) + """ + super().__init__() + + # Cache for homogeneous matrix + self._homogeneous_cache: NDArray[np.float64] | None = None + self._cache_valid = False + + # Handle homogeneous matrix input + if homogeneous is not None: + self._from_homogeneous(homogeneous) + return + + # Store translation + self._translation = np.array(translation) if translation is not None else np.zeros(3) + + # Store rotation as scipy Rotation object internally + if rotation is not None: + if rotation_type is None: + raise ValueError("rotation_type must be specified when rotation is provided") + self._set_rotation(rotation, rotation_type, rotation_order, degrees) + else: + self._rotation = Rotation.identity() + + def _from_homogeneous(self, homogeneous: np.ndarray): + """Initialize from homogeneous transformation matrix""" + homogeneous = np.array(homogeneous) + + # Extract translation (last column, first 3 rows) + self._translation = homogeneous[:3, 3] + + # Extract rotation matrix (top-left 3x3) + rotation_matrix = homogeneous[:3, :3] + + # Create Rotation object from matrix + self._rotation = Rotation.from_matrix(rotation_matrix) + + @staticmethod + def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray: + """ + Convert 6D rotation representation to rotation matrix. + + Args: + rot6d: 6D rotation as (6,) array - first two rows of rotation matrix flattened + + Returns: + Rotation matrix (3, 3) + """ + rot6d = rot6d.reshape(2, 3) + + # First two rows + row1 = rot6d[0] + row2 = rot6d[1] + + # Normalize first row + row1 = row1 / np.linalg.norm(row1) + + # Gram-Schmidt orthogonalization for second row + row2 = row2 - np.dot(row1, row2) * row1 + row2 = row2 / np.linalg.norm(row2) + + # Third row is cross product + row3 = np.cross(row1, row2) + + # Construct rotation matrix + rotation_matrix = np.vstack([row1, row2, row3]) + + return rotation_matrix + + @staticmethod + def _matrix_to_rot6d(rotation_matrix: np.ndarray) -> np.ndarray: + """ + Convert rotation matrix to 6D rotation representation. + + Args: + rotation_matrix: Rotation matrix (3, 3) + + Returns: + 6D rotation - (6,) array (first two rows flattened) + """ + return rotation_matrix[:2, :].flatten() + + def _set_rotation( + self, + rotation: list | np.ndarray, + rotation_type: str, + rotation_order: str | None = None, + degrees: bool = True, + ): + """Internal method to set rotation from various representations""" + rotation = np.array(rotation) + rot_type = RotationType(rotation_type.lower()) + + if rot_type == RotationType.QUAT: + quat_order = QuatOrder(rotation_order.lower()) if rotation_order else QuatOrder.WXYZ + if quat_order == QuatOrder.WXYZ: + # scipy uses xyzw order, so convert + quat_xyzw = np.array([rotation[1], rotation[2], rotation[3], rotation[0]]) + else: + quat_xyzw = rotation + self._rotation = Rotation.from_quat(quat_xyzw) + + elif rot_type == RotationType.EULER: + euler_order = EulerOrder(rotation_order.lower()) if rotation_order else EulerOrder.XYZ + self._rotation = Rotation.from_euler(euler_order.value, rotation, degrees=degrees) + + elif rot_type == RotationType.ROTVEC: + self._rotation = Rotation.from_rotvec(rotation) + + elif rot_type == RotationType.MATRIX: + self._rotation = Rotation.from_matrix(rotation) + + elif rot_type == RotationType.ROT6D: + rotation_matrix = self._rot6d_to_matrix(rotation) + self._rotation = Rotation.from_matrix(rotation_matrix) + + else: + raise ValueError(f"Unknown rotation type: {rotation_type}") + + # Invalidate cache + self._cache_valid = False + + @property + def translation(self) -> np.ndarray: + """ + Get translation vector. + + Returns: + Translation array - shape (3,) + """ + return self._translation.copy() + + @property + def quat_wxyz(self) -> np.ndarray: + """Get rotation as quaternion in wxyz order (w, x, y, z)""" + return self.to_rotation("quat", "wxyz") + + @property + def quat_xyzw(self) -> np.ndarray: + """Get rotation as quaternion in xyzw order (x, y, z, w)""" + return self.to_rotation("quat", "xyzw") + + @property + def euler_xyz(self) -> np.ndarray: + """Get rotation as Euler angles in xyz order (degrees)""" + return self.to_rotation("euler", "xyz") + + @property + def rotvec(self) -> np.ndarray: + """Get rotation as rotation vector (axis-angle)""" + return self.to_rotation("rotvec") + + @property + def rotation_matrix(self) -> np.ndarray: + """Get rotation as 3x3 rotation matrix""" + return self.to_rotation("matrix") + + @property + def rot6d(self) -> np.ndarray: + """Get rotation as 6D representation (first two rows of rotation matrix)""" + return self.to_rotation("rot6d") + + @property + def xyz_rot6d(self) -> np.ndarray: + """Get pose as concatenated translation and 6D rotation (9,)""" + return np.concatenate([self._translation, self.rot6d]) + + @property + def xyz_rotvec(self) -> np.ndarray: + """Get pose as concatenated translation and rotation vector (6,)""" + return np.concatenate([self._translation, self.rotvec]) + + @property + def homogeneous(self) -> np.ndarray: + """ + Get homogeneous transformation matrix. + + Returns: + Homogeneous matrix - shape (4, 4) + """ + if not self._cache_valid: + self._homogeneous_cache = self._compute_homogeneous() + self._cache_valid = True + assert self._homogeneous_cache is not None + return self._homogeneous_cache.copy() + + def _compute_homogeneous(self) -> np.ndarray: + """Compute homogeneous transformation matrix""" + H = np.eye(4) + H[:3, :3] = self._rotation.as_matrix() + H[:3, 3] = self._translation + return H + + def to_rotation( + self, + rotation_type: str, + rotation_order: str | None = None, + degrees: bool = True, + ) -> np.ndarray: + """ + Get rotation in specified representation. + + Args: + rotation_type: Desired type ("quat", "euler", "rotvec", "matrix", "rot6d") + rotation_order: Order/convention for the rotation type + degrees: For Euler angles, return in degrees (default True) + + Returns: + Rotation in requested format + - Shape (4,) for quat + - Shape (3,) for euler/rotvec + - Shape (6,) for rot6d + - Shape (3, 3) for matrix + """ + rot_type = RotationType(rotation_type.lower()) + + if rot_type == RotationType.ROT6D: + rotation_matrix = self._rotation.as_matrix() + return self._matrix_to_rot6d(rotation_matrix) + + if rot_type == RotationType.QUAT: + quat_order = QuatOrder(rotation_order.lower()) if rotation_order else QuatOrder.WXYZ + quat_xyzw = self._rotation.as_quat() + if quat_order == QuatOrder.WXYZ: + return np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]]) + else: + return quat_xyzw + + elif rot_type == RotationType.EULER: + euler_order = EulerOrder(rotation_order.lower()) if rotation_order else EulerOrder.XYZ + return self._rotation.as_euler(euler_order.value, degrees=degrees) + + elif rot_type == RotationType.ROTVEC: + return self._rotation.as_rotvec() + + elif rot_type == RotationType.MATRIX: + return self._rotation.as_matrix() + + else: + raise ValueError(f"Unknown rotation type: {rotation_type}") + + def to_homogeneous(self) -> np.ndarray: + """ + Convert pose to homogeneous transformation matrix. + (Alias for the homogeneous property) + + Returns: + Homogeneous matrix - shape (4, 4) + """ + return self.homogeneous + + def set_rotation( + self, + rotation: list | np.ndarray, + rotation_type: str, + rotation_order: str | None = None, + degrees: bool = True, + ): + """ + Set rotation from specified representation. + + Args: + rotation: Rotation data + rotation_type: Type of rotation ("quat", "euler", "rotvec", "matrix", "rot6d") + rotation_order: Order/convention for the rotation type + degrees: For Euler angles, whether the input is in degrees (default True) + """ + self._set_rotation(rotation, rotation_type, rotation_order, degrees) + + def _compute_relative(self, other): # type: ignore[override] + """ + Compute relative transformation from other to self. + + The result represents the transformation needed to go from other's frame to self's frame. + Mathematically: T_relative = T_other^{-1} * T_self + + Args: + other: Reference end-effector pose + + Returns: + EndEffectorPose representing the relative transformation + """ + # Get homogeneous matrices + T_self = self.homogeneous + T_other = other.homogeneous + + # Compute relative transformation: T_other^{-1} * T_self + T_relative = relative_transformation(T_other, T_self) + + # Create new EndEffectorPose from relative transformation + return EndEffectorPose(homogeneous=T_relative) + + @classmethod + def from_action_format(cls, data: np.ndarray, action_format: ActionFormat) -> "EndEffectorPose": + """ + Create an EndEffectorPose from a flat array using the specified action format. + + This is the inverse of the xyz_rot6d / xyz_rotvec / homogeneous properties. + + Args: + data: Flat array whose layout depends on action_format. + action_format: One of ActionFormat.XYZ_ROT6D, XYZ_ROTVEC, or DEFAULT. + + Returns: + EndEffectorPose instance. + """ + if action_format == ActionFormat.XYZ_ROT6D: + return cls(translation=data[:3], rotation=data[3:], rotation_type="rot6d") + elif action_format == ActionFormat.XYZ_ROTVEC: + return cls(translation=data[:3], rotation=data[3:], rotation_type="rotvec") + elif action_format == ActionFormat.DEFAULT: + return cls(homogeneous=data.reshape(4, 4)) + else: + raise ValueError(f"Unsupported ActionFormat: {action_format}") + + def copy(self) -> "EndEffectorPose": + """ + Create a deep copy of this end-effector pose. + + Returns: + New EndEffectorPose instance with copied data + """ + return EndEffectorPose( + translation=self._translation.copy(), + rotation=self._rotation.as_quat(), + rotation_type="quat", + rotation_order="xyzw", + ) + + def __repr__(self) -> str: + quat = self.to_rotation("quat", "wxyz") + return f"EndEffectorPose(translation={self.translation}, rotation_quat_wxyz={quat})" + + def __eq__(self, other) -> bool: + if not isinstance(other, EndEffectorPose): + return False + return np.allclose(self._translation, other._translation) and np.allclose( + self._rotation.as_quat(), other._rotation.as_quat() + ) diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py new file mode 100644 index 00000000000..14e2b0b3a37 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py @@ -0,0 +1,643 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unified state and action processor for robotics.""" + +from copy import deepcopy + +import numpy as np + +from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import ( + ActionFormat, + ActionRepresentation, + ActionType, + ModalityConfig, +) +from vllm_omni.diffusion.models.gr00t.dataio.state_action.action_chunking import ( + EndEffectorActionChunk, + JointActionChunk, +) +from vllm_omni.diffusion.models.gr00t.dataio.state_action.pose import EndEffectorPose, JointPose +from vllm_omni.diffusion.models.gr00t.dataio.utils import ( + apply_sin_cos_encoding, + nested_dict_to_numpy, + normalize_values_meanstd, + normalize_values_minmax, + parse_modality_configs, + unnormalize_values_meanstd, + unnormalize_values_minmax, +) + + +class StateActionProcessor: + """ + Unified processor for robot state and action data. + + Handles: + - State normalization (min/max, mean/std, sin/cos encoding) + - Action normalization + - Absolute <-> Relative action representation conversion + - Action processing with state dependency + """ + + def __init__( + self, + modality_configs: dict[str, dict[str, ModalityConfig]], + statistics: (dict[str, dict[str, dict[str, dict[str, list[float]]]]] | None) = None, + use_percentiles: bool = False, + clip_outliers: bool = True, + apply_sincos_state_encoding: bool = False, + use_relative_action: bool = False, + ): + """ + Initialize unified state and action processor. + + Args: + modality_configs: Nested dict with structure: + {embodiment_tag: {modality: ModalityConfig}} + where modality in ["state", "action"] + Example: {"gr1": {"state": ModalityConfig(...), "action": ModalityConfig(...)}} + statistics: Optional nested dict with structure: + {embodiment_tag: {modality: {joint_group: {stat_type: values}}}} + where modality in ["state", "action", "relative_action"] + and stat_type in ["min", "max", "mean", "std", "q01", "q99"] + Example: {"gr1": {"state": {"left_arm": {"min": [...], "max": [...], ...}}}} + use_percentiles: Whether to use percentiles (q01/q99) instead of min/max + clip_outliers: Whether to clip normalized values to [-1, 1] + apply_sincos_state_encoding: Global flag to enable sin/cos encoding for states + """ + self.modality_configs = parse_modality_configs(modality_configs) + self.statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]] = {} + self.use_percentiles = use_percentiles + self.clip_outliers = clip_outliers + self.apply_sincos_state_encoding = apply_sincos_state_encoding + self.use_relative_action = use_relative_action + + # Normalization parameters computed from statistics + self.norm_params: dict[str, dict[str, dict[str, dict[str, np.ndarray]]]] = {} + # Format: norm_params[embodiment_tag][modality][joint_group][stat_type] + # where stat_type in ["min", "max", "mean", "std", "dim"] + + if statistics is not None: + self.set_statistics(statistics) + + def set_statistics( + self, + statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]], + override: bool = False, + ) -> None: + """ + Set dataset statistics for normalization. + + Args: + statistics: Nested dict with structure: + {embodiment_tag: {modality: {joint_group: {stat_type: values}}}} + """ + for key in statistics: + if key not in self.statistics or override: + self.statistics[key] = deepcopy(statistics[key]) + else: + print(f"Embodiment tag {key} already in statistics, skipping updating") + self._compute_normalization_parameters() + + def _compute_normalization_parameters(self) -> None: + """Compute and cache normalization parameters from statistics for all embodiments and modalities.""" + for embodiment_tag in self.statistics: + self.norm_params[embodiment_tag] = {} + + for modality in ["state", "action"]: + if modality not in self.statistics[embodiment_tag]: + continue + + self.norm_params[embodiment_tag][modality] = {} + + for joint_group, stats in self.statistics[embodiment_tag][modality].items(): + if self.use_percentiles: + min_vals = np.array(stats["q01"]) + max_vals = np.array(stats["q99"]) + else: + min_vals = np.array(stats["min"]) + max_vals = np.array(stats["max"]) + + mean_vals = np.array(stats["mean"]) + std_vals = np.array(stats["std"]) + + # Compute range, ensuring it's not zero + range_vals = max_vals - min_vals + range_vals = np.maximum(range_vals, 1e-8) + + self.norm_params[embodiment_tag][modality][joint_group] = { + "min": min_vals, + "max": max_vals, + "dim": np.array(range_vals.shape[0]), + "mean": mean_vals, + "std": std_vals, + } + + # Override absolute action stats with relative stats where specified + if "action" in self.modality_configs[embodiment_tag]: + modality_keys = self.modality_configs[embodiment_tag]["action"].modality_keys + action_configs = self.modality_configs[embodiment_tag]["action"].action_configs + + if action_configs is not None: + for key, action_config in zip(modality_keys, action_configs): + if action_config.rep == ActionRepresentation.RELATIVE and self.use_relative_action: + if "relative_action" not in self.statistics[embodiment_tag]: + raise ValueError( + f"Relative action statistics required for embodiment '{embodiment_tag}' " + f"but 'relative_action' not found in statistics" + ) + if key not in self.statistics[embodiment_tag]["relative_action"]: + raise ValueError( + f"Relative action statistics required for key '{key}' " + f"in embodiment '{embodiment_tag}' but not found" + ) + action_dim = self.norm_params[embodiment_tag]["action"][key]["dim"] + self.norm_params[embodiment_tag]["action"][key] = nested_dict_to_numpy( + self.statistics[embodiment_tag]["relative_action"][key] + ) + self.norm_params[embodiment_tag]["action"][key]["dim"] = action_dim + + def apply_state( + self, + state: dict[str, np.ndarray], + embodiment_tag: str, + ) -> dict[str, np.ndarray]: + """ + Apply state processing (normalization, encoding). + + Args: + state: Dict mapping joint_group -> raw state values + Shape per group: (..., D) where D is state dimension + embodiment_tag: Embodiment identifier (e.g., "gr1") + + Returns: + Dict mapping joint_group -> processed state values + - Sin/cos encoded groups: (..., 2*D) + - Other groups: (..., D) + """ + normalized_values = {} + state = deepcopy(state) # Avoid modifying input + + # Get sin/cos embedding keys if enabled + sin_cos_keys = None + if self.apply_sincos_state_encoding: + state_config = self.modality_configs[embodiment_tag].get("state") + if state_config and hasattr(state_config, "sin_cos_embedding_keys"): + sin_cos_keys = state_config.sin_cos_embedding_keys + + for joint_group in self.modality_configs[embodiment_tag]["state"].modality_keys: + if joint_group not in state: + raise KeyError(f"Joint group '{joint_group}' not found in state dict for embodiment '{embodiment_tag}'") + + # Strategy 1: Sin/cos encoding (doubles dimension) + if sin_cos_keys and joint_group in sin_cos_keys: + normalized_values[joint_group] = apply_sin_cos_encoding(state[joint_group]) + + # Strategy 2: Mean/std normalization + elif ( + hasattr( + self.modality_configs[embodiment_tag]["state"], + "mean_std_embedding_keys", + ) + and self.modality_configs[embodiment_tag]["state"].mean_std_embedding_keys + and joint_group in self.modality_configs[embodiment_tag]["state"].mean_std_embedding_keys + ): + params = self.norm_params[embodiment_tag]["state"][joint_group] + normalized = normalize_values_meanstd(state[joint_group], params) + normalized_values[joint_group] = normalized + + # Strategy 3: Min/max normalization to [-1, 1] + else: + params = self.norm_params[embodiment_tag]["state"][joint_group] + normalized = normalize_values_minmax(state[joint_group], params) + + if self.clip_outliers: + normalized = np.clip(normalized, -1.0, 1.0) + + normalized_values[joint_group] = normalized + + return normalized_values + + def unapply_state( + self, + state: dict[str, np.ndarray], + embodiment_tag: str, + ) -> dict[str, np.ndarray]: + """ + Reverse state processing (denormalization). + + Args: + state: Dict mapping joint_group -> processed state values + embodiment_tag: Embodiment identifier + + Returns: + Dict mapping joint_group -> raw state values + + Raises: + ValueError: If attempting to reverse sin/cos encoding (not reversible) + """ + unnormalized_values = {} + + # Get sin/cos embedding keys if enabled + sin_cos_keys = None + if self.apply_sincos_state_encoding: + state_config = self.modality_configs[embodiment_tag].get("state") + if state_config and hasattr(state_config, "sin_cos_embedding_keys"): + sin_cos_keys = state_config.sin_cos_embedding_keys + + for joint_group in self.modality_configs[embodiment_tag]["state"].modality_keys: + if joint_group not in state: + raise KeyError(f"Joint group '{joint_group}' not found in state dict for embodiment '{embodiment_tag}'") + + # Sin/cos encoding is not reversible + if sin_cos_keys and joint_group in sin_cos_keys: + raise ValueError( + f"Cannot unapply sin/cos encoding for joint group '{joint_group}' " + f"in embodiment '{embodiment_tag}'. This transformation is not reversible." + ) + + # Reverse mean/std normalization + elif ( + hasattr( + self.modality_configs[embodiment_tag]["state"], + "mean_std_embedding_keys", + ) + and self.modality_configs[embodiment_tag]["state"].mean_std_embedding_keys + and joint_group in self.modality_configs[embodiment_tag]["state"].mean_std_embedding_keys + ): + params = self.norm_params[embodiment_tag]["state"][joint_group] + unnormalized = unnormalize_values_meanstd(state[joint_group], params) + unnormalized_values[joint_group] = unnormalized + + # Reverse min/max normalization + else: + params = self.norm_params[embodiment_tag]["state"][joint_group] + unnormalized_values[joint_group] = unnormalize_values_minmax(state[joint_group], params) + + return unnormalized_values + + def apply_action( + self, + action: dict[str, np.ndarray], + embodiment_tag: str, + state: dict[str, np.ndarray] | None = None, + ) -> dict[str, np.ndarray]: + """ + Apply action processing (absolute->relative conversion, normalization). + + Processing order: + 1. Convert absolute actions to relative (if configured) + 2. Normalize actions + + Args: + action: Dict mapping joint_group -> raw action values + Shape per group: (T, D) where T is action horizon, D is action dimension + embodiment_tag: Embodiment identifier + state: Optional dict mapping joint_group -> raw state values + Required if any action group uses ActionRepresentation.RELATIVE + Shape per group: (T_state, D) where last timestep is used as reference + + Returns: + Dict mapping joint_group -> processed action values + Shape per group: (T, D) + + Raises: + ValueError: If state is None but required for relative action conversion + """ + action = deepcopy(action) # Avoid modifying input + + # Step 1: Convert absolute actions to relative (if needed) + modality_keys = self.modality_configs[embodiment_tag]["action"].modality_keys + action_configs = self.modality_configs[embodiment_tag]["action"].action_configs + + if action_configs is not None: + for key, action_config in zip(modality_keys, action_configs): + if action_config.rep == ActionRepresentation.RELATIVE and self.use_relative_action: + if state is None: + raise ValueError( + f"State dict required for relative action processing of key '{key}' " + f"in embodiment '{embodiment_tag}'" + ) + + # Determine which state key to use as reference + state_key = action_config.state_key if action_config.state_key else key + + if state_key not in state: + raise KeyError( + f"Reference state key '{state_key}' not found in state dict " + f"for embodiment '{embodiment_tag}'" + ) + + # Use last state as reference frame + reference_state = state[state_key][-1] + + # Convert absolute to relative + action[key] = self._convert_to_relative_action( + action=action[key], + reference_state=reference_state, + action_type=action_config.type, + action_format=action_config.format, + ) + + # Step 2: Normalize actions + normalized_values = {} + for joint_group in modality_keys: + if joint_group not in action: + raise KeyError( + f"Joint group '{joint_group}' not found in action dict for embodiment '{embodiment_tag}'" + ) + + params = self.norm_params[embodiment_tag]["action"][joint_group] + if ( + self.modality_configs[embodiment_tag]["action"].mean_std_embedding_keys is not None + and joint_group in self.modality_configs[embodiment_tag]["action"].mean_std_embedding_keys + ): + normalized = normalize_values_meanstd(action[joint_group], params) + else: + normalized = normalize_values_minmax(action[joint_group], params) + + if self.clip_outliers: + normalized = np.clip(normalized, -1.0, 1.0) + + normalized_values[joint_group] = normalized + + return normalized_values + + def unapply_action( + self, + action: dict[str, np.ndarray], + embodiment_tag: str, + state: dict[str, np.ndarray] | None = None, + ) -> dict[str, np.ndarray]: + """ + Reverse action processing (denormalization, relative->absolute conversion). + + Processing order: + 1. Denormalize actions + 2. Convert relative actions to absolute (if configured) + + Args: + action: Dict mapping joint_group -> processed action values + Shape per group: (T, D) or (B, T, D) for batched + embodiment_tag: Embodiment identifier + state: Optional dict mapping joint_group -> raw state values + Required if any action group uses ActionRepresentation.RELATIVE + Shape per group: (T_state, D) or (B, T_state, D) for batched + + Returns: + Dict mapping joint_group -> raw absolute action values + Shape per group: (T, D) or (B, T, D) for batched + + Raises: + ValueError: If state is None but required for relative->absolute conversion + """ + # Step 1: Unnormalize actions + unnormalized_values = {} + modality_keys = self.modality_configs[embodiment_tag]["action"].modality_keys + + for joint_group in modality_keys: + if joint_group not in action: + raise KeyError( + f"Joint group '{joint_group}' not found in action dict for embodiment '{embodiment_tag}'" + ) + + params = self.norm_params[embodiment_tag]["action"][joint_group] + group_values = action[joint_group] + + if ( + self.modality_configs[embodiment_tag]["action"].mean_std_embedding_keys is not None + and joint_group in self.modality_configs[embodiment_tag]["action"].mean_std_embedding_keys + ): + unnormalized = unnormalize_values_meanstd(group_values, params) + else: + unnormalized = unnormalize_values_minmax(group_values, params) + + unnormalized_values[joint_group] = unnormalized + + # Step 2: Convert relative actions to absolute (if needed) + action_configs = self.modality_configs[embodiment_tag]["action"].action_configs + + if action_configs is not None: + for key, action_config in zip(modality_keys, action_configs): + if action_config.rep == ActionRepresentation.RELATIVE and self.use_relative_action: + if state is None: + raise ValueError( + f"State dict required for relative->absolute conversion of key '{key}' " + f"in embodiment '{embodiment_tag}'" + ) + + # Determine which state key to use as reference + state_key = action_config.state_key if action_config.state_key else key + + if state_key not in state: + raise KeyError( + f"Reference state key '{state_key}' not found in state dict " + f"for embodiment '{embodiment_tag}'" + ) + + relative_action = unnormalized_values[key] + + # Handle batched and unbatched cases + is_batched = relative_action.ndim == 3 + if not is_batched: + assert relative_action.ndim == 2 + reference_state = state[state_key] + if reference_state.ndim == 2: + reference_state = reference_state[None, :] + relative_action = relative_action[None, :] + else: + reference_state = state[state_key] + if reference_state.ndim == 2: + reference_state = reference_state[None, :] + + # Convert batched relative actions to absolute + absolute_actions = [] + for s, a in zip(reference_state, relative_action): + # Use last timestep of state as reference + absolute_action = self._convert_to_absolute_action( + action=a, + reference_state=s[-1], + action_type=action_config.type, + action_format=action_config.format, + ) + absolute_actions.append(absolute_action) + + if is_batched: + unnormalized_values[key] = np.stack(absolute_actions, axis=0) + else: + unnormalized_values[key] = absolute_actions[0] + + return unnormalized_values + + def apply( + self, + state: dict[str, np.ndarray], + action: dict[str, np.ndarray], + embodiment_tag: str, + ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """ + Apply both state and action processing together. + + Convenience method that processes state and action in one call, + automatically passing raw state to action processor for relative conversion. + + Args: + state: Dict mapping joint_group -> raw state values + action: Dict mapping joint_group -> raw action values + embodiment_tag: Embodiment identifier + + Returns: + Tuple of (processed_state, processed_action) + """ + processed_state = self.apply_state(state, embodiment_tag) + if action: + processed_action = self.apply_action(action, embodiment_tag, state=state) + else: + processed_action = {} + return processed_state, processed_action + + def unapply( + self, + state: dict[str, np.ndarray], + action: dict[str, np.ndarray], + embodiment_tag: str, + raw_state: dict[str, np.ndarray] | None = None, + ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """ + Reverse both state and action processing together. + + Args: + state: Dict mapping joint_group -> processed state values + action: Dict mapping joint_group -> processed action values + embodiment_tag: Embodiment identifier + raw_state: Optional dict of raw states for relative->absolute conversion + If None, will use unapplied state (but won't work for sin/cos encoded states) + + Returns: + Tuple of (raw_state, raw_action) + """ + # Unapply state first + try: + unapplied_state = self.unapply_state(state, embodiment_tag) + except ValueError as e: + if "sin/cos encoding" in str(e) and raw_state is None: + raise ValueError("Cannot unapply sin/cos encoded state. Please provide raw_state parameter.") from e + raise + + # Use provided raw_state if available, otherwise use unapplied state + state_for_action = raw_state if raw_state is not None else unapplied_state + + # Unapply action + unapplied_action = self.unapply_action(action, embodiment_tag, state=state_for_action) + + return unapplied_state, unapplied_action + + def get_state_dim(self, embodiment_tag: str, include_sincos_expansion: bool = False) -> int: + """ + Get total state dimension after processing. + + Args: + embodiment_tag: Embodiment identifier + include_sincos_expansion: If True, accounts for sin/cos encoding doubling dimensions + + Returns: + Total state dimension across all joint groups + """ + total_dim = 0 + state_config = self.modality_configs[embodiment_tag]["state"] + + # Get sin/cos embedding keys if enabled + sin_cos_keys = set() + if self.apply_sincos_state_encoding and hasattr(state_config, "sin_cos_embedding_keys"): + sin_cos_keys = set(state_config.sin_cos_embedding_keys) + + for joint_group in state_config.modality_keys: + base_dim = self.norm_params[embodiment_tag]["state"][joint_group]["dim"].item() + + # Sin/cos encoding doubles the dimension + if include_sincos_expansion and joint_group in sin_cos_keys: + total_dim += base_dim * 2 + else: + total_dim += base_dim + + return total_dim + + def get_action_dim(self, embodiment_tag: str) -> int: + """ + Get total action dimension. + + Args: + embodiment_tag: Embodiment identifier + + Returns: + Total action dimension across all joint groups + """ + total_dim = 0 + for joint_group in self.modality_configs[embodiment_tag]["action"].modality_keys: + total_dim += self.norm_params[embodiment_tag]["action"][joint_group]["dim"].item() + return total_dim + + def _convert_to_relative_action( + self, + action: np.ndarray, + reference_state: np.ndarray, + action_type: ActionType, + action_format: ActionFormat, + ) -> np.ndarray: + """Convert absolute action to relative action using reference state.""" + assert action.ndim == 2, f"Expected action shape (T, D), got {action.shape}" + assert reference_state.ndim == 1, f"Expected state shape (D,), got {reference_state.shape}" + + if action_type == ActionType.EEF: + action_chunking = EndEffectorActionChunk.from_array(action, action_format) + reference_frame = EndEffectorPose.from_action_format(reference_state, action_format) + + elif action_type == ActionType.NON_EEF: + action_chunking = JointActionChunk([JointPose(m) for m in action]) + reference_frame = JointPose(reference_state) + + else: + raise ValueError(f"Unknown ActionType: {action_type}") + + relative_action_chunking = action_chunking.relative_chunking(reference_frame=reference_frame) + return relative_action_chunking.to(action_format) + + def _convert_to_absolute_action( + self, + action: np.ndarray, + reference_state: np.ndarray, + action_type: ActionType, + action_format: ActionFormat, + ) -> np.ndarray: + """Convert relative action to absolute action using reference state.""" + assert action.ndim == 2, f"Expected action shape (T, D), got {action.shape}" + assert reference_state.ndim == 1, f"Expected state shape (D,), got {reference_state.shape}" + assert reference_state.shape[0] == action.shape[1], ( + f"State dim {reference_state.shape[0]} != action dim {action.shape[1]}" + ) + + if action_type == ActionType.EEF: + rel_action = EndEffectorActionChunk.from_array(action, action_format) + reference_frame = EndEffectorPose.from_action_format(reference_state, action_format) + + elif action_type == ActionType.NON_EEF: + rel_action = JointActionChunk([JointPose(pose) for pose in action]) + reference_frame = JointPose(reference_state) + + else: + raise ValueError(f"Unknown ActionType: {action_type}") + + abs_action = rel_action.to_absolute_chunking(reference_frame=reference_frame) + return abs_action.to(action_format) + + def __str__(self) -> str: + return ( + "StateActionProcessor(" + f"modality_configs={self.modality_configs}, " + f"statistics={self.statistics}, " + f"use_percentiles={self.use_percentiles}, " + f"clip_outliers={self.clip_outliers}, " + f"apply_sincos_state_encoding={self.apply_sincos_state_encoding}, " + f"use_relative_action={self.use_relative_action})" + ) diff --git a/vllm_omni/diffusion/models/gr00t/dataio/types.py b/vllm_omni/diffusion/models/gr00t/dataio/types.py new file mode 100644 index 00000000000..d4948a1c234 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/types.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +import numpy as np + +from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag + + +class MessageType(Enum): + START_OF_EPISODE = "start_of_episode" + END_OF_EPISODE = "end_of_episode" + EPISODE_STEP = "episode_step" + IMAGE = "image" + TEXT = "text" + + +class ActionRepresentation(Enum): + RELATIVE = "relative" + DELTA = "delta" + ABSOLUTE = "absolute" + + +class ActionType(Enum): + EEF = "eef" + NON_EEF = "non_eef" + + +class ActionFormat(Enum): + DEFAULT = "default" + XYZ_ROT6D = "xyz+rot6d" + XYZ_ROTVEC = "xyz+rotvec" + + +@dataclass +class VLAStepData: + """ + Represents a single step of VLA (Vision-Language-Action) data. + + This is the core data structure returned by datasets, containing raw observation + and action data that will be processed by the SequenceVLAProcessor. + """ + + # Core data + images: dict[str, list[np.ndarray]] # view_name -> list[np.ndarray] (for temporal stacking) + states: dict[str, np.ndarray] # state_name -> np.ndarray (dim,) for single step or (horizon, dim) for trajectory + actions: dict[str, np.ndarray] # action_name -> np.ndarray (horizon, dim) for action chunk + masks: dict[str, list[np.ndarray]] | None = None # view_name -> list[np.ndarray] (H, W) + text: str | None = None # Optional task description or instruction + embodiment: EmbodimentTag = EmbodimentTag.NEW_EMBODIMENT # Optional embodiment tag for cross-embodiment training + is_demonstration: bool = ( + False # Whether the step is a demonstration. If True, no loss should be computed for this step. + ) + + # Flexible metadata that can be extended by users + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ActionConfig: + rep: ActionRepresentation + type: ActionType + format: ActionFormat + state_key: str | None = None + + +@dataclass +class ModalityConfig: + """Configuration for a modality defining how data should be sampled and loaded. + + This class specifies which indices to sample relative to a base index and which + keys to load for a particular modality (e.g., video, state, action). + """ + + delta_indices: list[int] + """Delta indices to sample relative to the current index. The returned data will + correspond to the original data at a sampled base index + delta indices.""" + modality_keys: list[str] + """The keys to load for the modality in the dataset.""" + sin_cos_embedding_keys: list[str] | None = None + """Optional list of keys to apply sin/cos encoding. If None or empty, use + min/max normalization for all keys.""" + mean_std_embedding_keys: list[str] | None = None + """Optional list of keys to apply mean/std normalization. If None or empty, + use min/max normalization for all keys.""" + action_configs: list[ActionConfig] | None = None + + def __post_init__(self): + """Validate fields and set default values.""" + if self.delta_indices is None or not isinstance(self.delta_indices, list): + raise ValueError(f"delta_indices must be a non-None list, got {self.delta_indices!r}") + if self.modality_keys is None or not isinstance(self.modality_keys, list) or len(self.modality_keys) == 0: + raise ValueError(f"modality_keys must be a non-empty list, got {self.modality_keys!r}") + if self.action_configs is not None: + assert len(self.action_configs) == len(self.modality_keys), ( + f"Number of action configs ({len(self.action_configs)}) must match " + f"number of modality keys ({len(self.modality_keys)})" + ) + parsed_action_configs = [] + for action_config in self.action_configs: + if isinstance(action_config, dict): + action_config = ActionConfig( + rep=ActionRepresentation[action_config["rep"]], + type=ActionType[action_config["type"]], + format=ActionFormat[action_config["format"]], + state_key=action_config.get("state_key", None), + ) + parsed_action_configs.append(action_config) + self.action_configs = parsed_action_configs diff --git a/vllm_omni/diffusion/models/gr00t/dataio/utils.py b/vllm_omni/diffusion/models/gr00t/dataio/utils.py new file mode 100644 index 00000000000..8096d62396c --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/utils.py @@ -0,0 +1,291 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import asdict, is_dataclass +from enum import Enum +from typing import Any + +import numpy as np + +from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import ModalityConfig + + +def apply_sin_cos_encoding(values: np.ndarray) -> np.ndarray: + """Apply sin/cos encoding to values. + + Args: + values: Array of shape (..., D) containing values to encode + + Returns: + Array of shape (..., 2*D) with [sin, cos] concatenated + + Note: This DOUBLES the dimension. For example: + Input: [v₁, v₂, v₃] with shape (..., 3) + Output: [sin(v₁), sin(v₂), sin(v₃), cos(v₁), cos(v₂), cos(v₃)] with shape (..., 6) + """ + sin_values = np.sin(values) + cos_values = np.cos(values) + # Concatenate sin and cos: [sin(v1), sin(v2), ..., cos(v1), cos(v2), ...] + return np.concatenate([sin_values, cos_values], axis=-1) + + +def nested_dict_to_numpy(data): + """ + Recursively converts bottom-level list of lists to NumPy arrays. + + Args: + data: A nested dictionary where bottom nodes are list of lists, + and parent nodes are strings (keys) + + Returns: + The same dictionary structure with bottom-level lists converted to NumPy arrays + + Example: + >>> data = {"a": {"b": [[0, 1], [2, 3]]}} + >>> result = nested_dict_to_numpy(data) + >>> print(result["a"]["b"]) + [[0 1] + [2 3]] + """ + if isinstance(data, dict): + return {key: nested_dict_to_numpy(value) for key, value in data.items()} + elif isinstance(data, list): + # Convert lists to numpy arrays + # NumPy will handle both 1D and 2D cases appropriately + return np.array(data) + else: + return data + + +def normalize_values_minmax(values, params): + """ + Normalize values using min-max normalization to [-1, 1] range. + + Args: + values: Input values to normalize + - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension + - Can handle 2D or 3D arrays where last axis represents features + params: Dictionary with "min" and "max" keys + - params["min"]: Minimum values for normalization + * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps + * Case 2 - 2D bounds: Shape (T, D) - different min/max per step + - params["max"]: Maximum values for normalization + * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps + * Case 2 - 2D bounds: Shape (T, D) - different min/max per step + joint_group: Optional indexing for joint groups (legacy parameter) + + Returns: + Normalized values in [-1, 1] range + - Same shape as input values: (T, D) or (B, T, D) + - Values are linearly mapped from [min, max] to [-1, 1] + - For features where min == max, normalized value is 0 + + Examples: + # 1D bounds - same normalization for all steps + values: (10, 5), params["min"]: (5,), params["max"]: (5,) + + # 2D bounds - per-step normalization + values: (8, 4), params["min"]: (8, 4), params["max"]: (8, 4) + """ + min_vals = params["min"] + max_vals = params["max"] + normalized = np.zeros_like(values) + + mask = ~np.isclose(max_vals, min_vals) + + normalized[..., mask] = (values[..., mask] - min_vals[..., mask]) / (max_vals[..., mask] - min_vals[..., mask]) + normalized[..., mask] = 2 * normalized[..., mask] - 1 + + return normalized + + +def unnormalize_values_minmax(normalized_values, params): + """ + Min-max unnormalization from [-1, 1] range back to original range. + + Args: + normalized_values: Normalized input values in [-1, 1] range + - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension + - Values outside [-1, 1] are automatically clipped + params: Dictionary with "min" and "max" keys + - params["min"]: Original minimum values used for normalization + * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps + * Case 2 - 2D bounds: Shape (T, D) - different min/max per step + - params["max"]: Original maximum values used for normalization + * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps + * Case 2 - 2D bounds: Shape (T, D) - different min/max per step + + Returns: + Unnormalized values in original range [min, max] + - Same shape as input normalized_values: (T, D) or (B, T, D) + - Values are linearly mapped from [-1, 1] back to [min, max] + - Input values are clipped to [-1, 1] before unnormalization + + Examples: + # 1D bounds - same unnormalization for all steps + normalized_values: (10, 5), params["min"]: (5,), params["max"]: (5,) + + # 2D bounds - per-step unnormalization + normalized_values: (8, 4), params["min"]: (8, 4), params["max"]: (8, 4) + """ + + min_vals = params["min"] + max_vals = params["max"] + range_vals = max_vals - min_vals + + # Unnormalize from [-1, 1] + unnormalized = (np.clip(normalized_values, -1.0, 1.0) + 1.0) / 2.0 * range_vals + min_vals + return unnormalized + + +def normalize_values_meanstd(values, params): + """ + Normalize values using mean-std (z-score) normalization. + + Args: + values: Input values to normalize + - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension + - Can handle 2D or 3D arrays where last axis represents features + params: Dictionary with "mean" and "std" keys + - params["mean"]: Mean values for normalization + * Case 1 - 1D params: Shape (D,) - same mean for all steps + * Case 2 - 2D params: Shape (T, D) - different mean per step + - params["std"]: Standard deviation values for normalization + * Case 1 - 1D params: Shape (D,) - same std for all steps + * Case 2 - 2D params: Shape (T, D) - different std per step + + Returns: + Normalized values using z-score normalization + - Same shape as input values: (T, D) or (B, T, D) + - Values are transformed as: (x - mean) / std + - For features where std == 0, normalized value equals original value + + Examples: + # 1D params - same normalization for all steps + values: (10, 5), params["mean"]: (5,), params["std"]: (5,) + + # 2D params - per-step normalization + values: (8, 4), params["mean"]: (8, 4), params["std"]: (8, 4) + """ + mean_vals = params["mean"] + std_vals = params["std"] + + # Create mask for non-zero standard deviations + mask = std_vals != 0 + + # Initialize normalized array + normalized = np.zeros_like(values) + + # Normalize only features with non-zero std + normalized[..., mask] = (values[..., mask] - mean_vals[..., mask]) / std_vals[..., mask] + + # Keep original values for zero-std features + normalized[..., ~mask] = values[..., ~mask] + + return normalized + + +def unnormalize_values_meanstd(normalized_values, params): + """ + Mean-std unnormalization (reverse z-score normalization). + + Args: + normalized_values: Normalized input values (z-scores) + - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension + - Can handle 2D or 3D arrays where last axis represents features + params: Dictionary with "mean" and "std" keys + - params["mean"]: Original mean values used for normalization + * Case 1 - 1D params: Shape (D,) - same mean for all steps + * Case 2 - 2D params: Shape (T, D) - different mean per step + - params["std"]: Original standard deviation values used for normalization + * Case 1 - 1D params: Shape (D,) - same std for all steps + * Case 2 - 2D params: Shape (T, D) - different std per step + + Returns: + Unnormalized values in original scale + - Same shape as input normalized_values: (T, D) or (B, T, D) + - Values are transformed as: x * std + mean + - For features where std == 0, unnormalized value equals normalized value + + Examples: + # 1D params - same unnormalization for all steps + normalized_values: (10, 5), params["mean"]: (5,), params["std"]: (5,) + + # 2D params - per-step unnormalization + normalized_values: (8, 4), params["mean"]: (8, 4), params["std"]: (8, 4) + """ + mean_vals = params["mean"] + std_vals = params["std"] + + # Create mask for non-zero standard deviations + mask = std_vals != 0 + + # Initialize unnormalized array + unnormalized = np.zeros_like(normalized_values) + + # Unnormalize only features with non-zero std + unnormalized[..., mask] = normalized_values[..., mask] * std_vals[..., mask] + mean_vals[..., mask] + + # Keep normalized values for zero-std features + unnormalized[..., ~mask] = normalized_values[..., ~mask] + + return unnormalized + + +def to_json_serializable(obj: Any) -> Any: + """ + Recursively convert dataclasses and numpy arrays to JSON-serializable format. + + Args: + obj: Object to convert (can be dataclass, numpy array, dict, list, etc.) + + Returns: + JSON-serializable representation of the object + """ + if is_dataclass(obj) and not isinstance(obj, type): + # Convert dataclass to dict, then recursively process the dict + return to_json_serializable(asdict(obj)) + elif isinstance(obj, np.ndarray): + # Convert numpy array to list + return obj.tolist() + elif isinstance(obj, np.integer): + # Convert numpy integers to Python int + return int(obj) + elif isinstance(obj, np.floating): + # Convert numpy floats to Python float + return float(obj) + elif isinstance(obj, np.bool_): + # Convert numpy bool to Python bool + return bool(obj) + elif isinstance(obj, dict): + # Recursively process dictionary values + return {key: to_json_serializable(value) for key, value in obj.items()} + elif isinstance(obj, (list, tuple)): + # Recursively process list/tuple elements + return [to_json_serializable(item) for item in obj] + elif isinstance(obj, set): + # Convert set to list + return [to_json_serializable(item) for item in obj] + elif isinstance(obj, (str, int, float, bool, type(None))): + # Already JSON-serializable + return obj + elif isinstance(obj, Enum): + return obj.name + else: + # For other types, try to convert to string as fallback + # You might want to handle specific types differently + return str(obj) + + +def parse_modality_configs( + modality_configs: dict[str, dict[str, ModalityConfig]], +) -> dict[str, dict[str, ModalityConfig]]: + parsed_modality_configs = {} + for embodiment_tag, modality_config in modality_configs.items(): + parsed_modality_configs[embodiment_tag] = {} + for modality, config in modality_config.items(): + if isinstance(config, dict): + parsed_modality_configs[embodiment_tag][modality] = ModalityConfig(**config) + else: + parsed_modality_configs[embodiment_tag][modality] = config + return parsed_modality_configs diff --git a/vllm_omni/diffusion/models/gr00t/modeling/__init__.py b/vllm_omni/diffusion/models/gr00t/modeling/__init__.py new file mode 100644 index 00000000000..7aecb405ff9 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7 +from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor + +__all__ = ["Gr00tN1d7", "Gr00tN1d7Processor"] diff --git a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py new file mode 100644 index 00000000000..4799b7cf615 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py @@ -0,0 +1,479 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging +import os +from typing import Any + +import torch +from torch import nn +from transformers import AutoConfig, AutoModel, PreTrainedModel +from transformers.feature_extraction_utils import BatchFeature +from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration as _Qwen3VLForConditionalGeneration, +) + +from vllm_omni.diffusion.models.gr00t.configs.gr00t_n1d7 import Gr00tN1d7Config +from vllm_omni.diffusion.models.gr00t.modeling.modules.dit import AlternateVLDiT, DiT, SelfAttentionTransformer +from vllm_omni.diffusion.models.gr00t.modeling.modules.embodiment_conditioned_mlp import ( + CategorySpecificMLP, + MultiEmbodimentActionEncoder, +) +from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7DataCollator + +logger = logging.getLogger(__name__) + + +class Gr00tN1d7ActionHead(nn.Module): + """Action head component for flow matching diffusion policy.""" + + supports_gradient_checkpointing = True + + def __init__(self, config: Gr00tN1d7Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.input_embedding_dim = config.input_embedding_dim + + if config.use_alternate_vl_dit: + self.model = AlternateVLDiT( + **config.diffusion_model_cfg, + cross_attention_dim=config.backbone_embedding_dim, + attend_text_every_n_blocks=config.attend_text_every_n_blocks, + ) + logger.info("Using AlternateVLDiT for diffusion model") + else: + self.model = DiT( + **config.diffusion_model_cfg, + cross_attention_dim=config.backbone_embedding_dim, + ) + logger.info("Using DiT for diffusion model") + self.action_dim = config.max_action_dim + self.action_horizon = config.action_horizon + self.num_inference_timesteps = config.num_inference_timesteps + + self.state_encoder = CategorySpecificMLP( + num_categories=config.max_num_embodiments, + input_dim=config.max_state_dim * config.state_history_length, + hidden_dim=self.hidden_size, + output_dim=self.input_embedding_dim, + ) + self.action_encoder = MultiEmbodimentActionEncoder( + action_dim=self.action_dim, + hidden_size=self.input_embedding_dim, + num_embodiments=config.max_num_embodiments, + ) + self.action_decoder = CategorySpecificMLP( + num_categories=config.max_num_embodiments, + input_dim=self.hidden_size, + hidden_dim=self.hidden_size, + output_dim=self.action_dim, + ) + + self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity() + + vl_self_attention_cfg = config.vl_self_attention_cfg + if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0: + self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg) + else: + self.vl_self_attention = nn.Identity() + + if config.add_pos_embed: + self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim) + nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) + + self.num_timestep_buckets = config.num_timestep_buckets + + def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature: + backbone_features = backbone_output["backbone_features"] + backbone_features = self.vlln(backbone_features) + backbone_features = self.vl_self_attention(backbone_features) + backbone_output["backbone_features"] = backbone_features + return backbone_output + + def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature: + """ + Encode features for the action head. + + Args: + backbone_output: Output from the backbone model containing: + - backbone_features: [B, seq_len, backbone_embedding_dim] + - backbone_attention_mask: [B, seq_len] + action_input: Input containing: + - state: [B, state_history_length, max_state_dim] + - embodiment_id: [B] (embodiment IDs) + + Returns: + BatchFeature containing: + - backbone_features: [B, seq_len, backbone_embedding_dim] + - state_features: [B, 1, input_embedding_dim] + """ + backbone_output = self.process_backbone_output(backbone_output) + + vl_embeds = backbone_output.backbone_features + embodiment_id = action_input.embodiment_id + + state = action_input.state + current_T = state.shape[1] + assert current_T == self.config.state_history_length, "current_T != state_history_length" + # [B, state_history_length, max_state_dim] -> [B, 1, state_history_length * max_state_dim] + state = state.view(state.shape[0], 1, -1) + + state_features = self.state_encoder(state, embodiment_id) + + return BatchFeature(data={"backbone_features": vl_embeds, "state_features": state_features}) + + @torch.no_grad() + def get_action_with_features( + self, + backbone_features: torch.Tensor, + state_features: torch.Tensor, + embodiment_id: torch.Tensor, + backbone_output: BatchFeature, + action_input: BatchFeature, + options: dict[str, Any] | None = None, + ) -> BatchFeature: + """ + Generate actions using the flow matching diffusion process. + + Args: + backbone_features: [B, seq_len, backbone_embedding_dim] + state_features: [B, state_horizon, input_embedding_dim] + embodiment_id: [B] (embodiment IDs) + backbone_output: Output from the backbone model + """ + vl_embeds = backbone_features + + batch_size = vl_embeds.shape[0] + device = vl_embeds.device + _seed_env = os.environ.get("GR00T_NOISE_SEED") + if _seed_env is not None: + _gen = torch.Generator(device=device).manual_seed(int(_seed_env)) + actions = torch.randn( + size=(batch_size, self.config.action_horizon, self.action_dim), + dtype=vl_embeds.dtype, + device=device, + generator=_gen, + ) + else: + actions = torch.randn( + size=(batch_size, self.config.action_horizon, self.action_dim), + dtype=vl_embeds.dtype, + device=device, + ) + + dt = 1.0 / self.num_inference_timesteps + vel_strength = torch.ones_like(actions) + + if "action" in action_input: + # If action in input when doing get action, it means we want to use RTC. + # action_horizon is the action horizon of the input action. + # rtc_overlap_steps is the number of steps to overlap with the previous action chunks. + # rtc_frozen_steps is the policy inference latency expressed as frozen action steps. + # rtc_ramp_rate is the rate of the ramp of denoising the actions. + assert options is not None, "options is not None" + assert "action_horizon" in options, "action_horizon is not in options" + assert "rtc_overlap_steps" in options, "rtc_overlap_steps is not in options" + assert "rtc_frozen_steps" in options, "rtc_frozen_steps is not in options" + assert "rtc_ramp_rate" in options, "rtc_ramp_rate is not in options" + + action_horizon_before_padding = options["action_horizon"] + + # Use previous action instead of pure noise to do inpainting + actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][ + :, + action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding, + :, + ] + vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0 + # NOTE: use an exponential ramp strength to set the remaining unfrozen rtc_steps + intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"] + # Create exponential ramp from 0 to 1 over intermediate steps + t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device) + ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t) + ramp = ramp / ramp[-1].clamp_min(1e-8) # normalize to [0,1] + ramp = ramp[1:-1] # we will only take the middle part of the ramp, ignore the 0.0 and 1.0 + # Apply ramp to the intermediate steps [batch, intermediate_steps, action_dim] + vel_strength[ + :, + options["rtc_frozen_steps"] : options["rtc_overlap_steps"], + :, + ] = ramp[None, :, None].to(device) + + # Run denoising steps. + for t in range(self.num_inference_timesteps): + t_discretized = t * self.num_timestep_buckets // self.num_inference_timesteps + + # Embed noised action trajectory. + timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device) + action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id) + # Add position embedding. + if self.config.add_pos_embed: + pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device) + pos_embs = self.position_embedding(pos_ids).unsqueeze(0) + action_features = action_features + pos_embs + + # Join vision, language, state and action embedding along sequence dimension. + sa_embs = torch.cat((state_features, action_features), dim=1) + + # Run model forward. + if self.config.use_alternate_vl_dit: + model_output = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embeds, + timestep=timesteps_tensor, + image_mask=backbone_output.image_mask, + backbone_attention_mask=backbone_output.backbone_attention_mask, + ) + else: + model_output = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embeds, + timestep=timesteps_tensor, + ) + pred = self.action_decoder(model_output, embodiment_id) + + pred_velocity = pred[:, -self.action_horizon :] + + # Update actions using euler integration. + actions = actions + dt * pred_velocity * vel_strength + + return BatchFeature( + data={ + "action_pred": actions, + "backbone_features": vl_embeds, + "state_features": state_features, + } + ) + + @torch.no_grad() + def get_action( + self, + backbone_output: BatchFeature, + action_input: BatchFeature, + options: dict[str, Any] | None = None, + ) -> BatchFeature: + """ + Generate actions using the flow matching diffusion process. + + Args: + backbone_output: Output from the backbone model containing: + - backbone_features: [B, seq_len, backbone_embedding_dim] + - backbone_attention_mask: [B, seq_len] + action_input: Input containing: + - state: [B, state_dim] + - embodiment_id: [B] (embodiment IDs) + + Returns: + BatchFeature containing: + - action_pred: [B, action_horizon, action_dim] predicted actions + """ + features = self._encode_features(backbone_output, action_input) + return self.get_action_with_features( + backbone_features=features.backbone_features, + state_features=features.state_features, + embodiment_id=action_input.embodiment_id, + backbone_output=backbone_output, + action_input=action_input, + options=options, + ) + + @property + def device(self): + return next(iter(self.parameters())).device + + @property + def dtype(self): + return next(iter(self.parameters())).dtype + + def prepare_input(self, batch: dict) -> BatchFeature: + """Prepare input batch for the action head.""" + return BatchFeature(data=batch) + + +class _Qwen3VLBackbone(nn.Module): + """GR00T adapter around the shared Qwen3-VL implementation.""" + + def __init__( + self, + model_name: str, + select_layer: int, + backbone_embedding_dim: int, + load_bf16: bool, + transformers_loading_kwargs: dict[str, Any] | None = None, + ): + super().__init__() + backbone_config = AutoConfig.from_pretrained( + model_name, **(transformers_loading_kwargs or {"trust_remote_code": True}) + ) + backbone_config.text_config._attn_implementation = "sdpa" + backbone_config.vision_config._attn_implementation = "sdpa" + self.model = _Qwen3VLForConditionalGeneration(backbone_config).eval() + if load_bf16: + self.model.to(dtype=torch.bfloat16) + + target_layers = select_layer if select_layer >= 0 else len(self.model.model.language_model.layers) + while len(self.model.model.language_model.layers) > target_layers: + self.model.model.language_model.layers.pop(-1) + + def set_frozen_modules_to_eval_mode(self) -> None: + self.eval() + + def prepare_input(self, batch: dict) -> BatchFeature: + return BatchFeature(data=batch) + + def forward(self, vl_input: BatchFeature) -> BatchFeature: + self.set_frozen_modules_to_eval_mode() + keys_to_use = [ + "input_ids", + "attention_mask", + "pixel_values", + "image_grid_thw", + "mm_token_type_ids", + ] + vl_input = {key: vl_input[key] for key in keys_to_use if key in vl_input} + # GR00T was trained against the **pre-final-RMSNorm** output of the last + # decoder layer (this was `outer.hidden_states[-1]` under HF 4.57.x semantics, + # where the @check_model_inputs tie-logic skipped overwriting because + # Qwen3VLCausalLMOutputWithPast has no `last_hidden_state` field). + # HF >=5.x changed the output_hidden_states tuple plumbing so that + # `hidden_states[-1]` is post-norm, which is a different tensor and breaks + # the trained action head. Capture the pre-norm tensor via a hook on the + # final RMSNorm regardless of transformers version. + _captured: list[torch.Tensor] = [] + + def _pre_norm_hook(_module, args, _out): + _captured.append(args[0]) + + norm = self.model.model.language_model.norm + handle = norm.register_forward_hook(_pre_norm_hook) + try: + self.model.model(**vl_input, return_dict=True) + finally: + handle.remove() + if not _captured: + raise RuntimeError("Failed to capture pre-norm hidden states from Qwen3-VL backbone") + backbone_features = _captured[-1] + image_mask = vl_input["input_ids"] == self.model.config.image_token_id + attention_mask = vl_input["attention_mask"] == 1 + return BatchFeature( + data={ + "backbone_features": backbone_features, + "backbone_attention_mask": attention_mask, + "image_mask": image_mask, + } + ) + + +def get_backbone_cls(config: Gr00tN1d7Config): + if "nvidia/Cosmos-Reason2" in config.model_name: + return _Qwen3VLBackbone + raise ValueError(f"Unsupported model name: {config.model_name}") + + +class Gr00tN1d7(PreTrainedModel): + """Gr00tN1d7: VLA model with Cosmos-Reason2-2B (Qwen3-VL) backbone.""" + + config_class = Gr00tN1d7Config + supports_gradient_checkpointing = True + _tp_plan = {} + + @property + def all_tied_weights_keys(self) -> dict[str, Any]: + return {} + + def __init__( + self, + config: Gr00tN1d7Config, + transformers_loading_kwargs: dict | None = None, + ): + super().__init__(config) + if transformers_loading_kwargs is None: + transformers_loading_kwargs = {"trust_remote_code": True} + self.config = config + + backbone_cls = get_backbone_cls(config) + self.backbone = backbone_cls( + model_name=config.model_name, + select_layer=config.select_layer, + backbone_embedding_dim=config.backbone_embedding_dim, + load_bf16=config.load_bf16, + transformers_loading_kwargs=transformers_loading_kwargs, + ) + + # Initialize action head + self.action_head = Gr00tN1d7ActionHead(config) + + self.collator = Gr00tN1d7DataCollator( + model_name=config.model_name, + model_type=config.backbone_model_type, + transformers_loading_kwargs=transformers_loading_kwargs, + ) + + def prepare_input(self, inputs: dict) -> tuple[BatchFeature, BatchFeature]: + """Prepare inputs for backbone and action head.""" + + # NOTE -- currently the eval code doesn't use collator, so we need to add it here + if "vlm_content" in inputs: + # Fix for n_envs > 1: Process all environments' VLM content, not just the first + vlm_content_list = inputs["vlm_content"] + # Ensure vlm_content_list is always a list for consistent processing + if not isinstance(vlm_content_list, list): + vlm_content_list = [vlm_content_list] + + # Process all VLM contents through the collator + prep = self.collator([{"vlm_content": vlm} for vlm in vlm_content_list])["inputs"] + inputs.pop("vlm_content") + inputs.update(prep) + + backbone_inputs = self.backbone.prepare_input(inputs) + action_inputs = self.action_head.prepare_input(inputs) + + backbone_inputs = backbone_inputs.to(device=self.device, dtype=self.dtype) + action_inputs = action_inputs.to(device=self.device, dtype=self.dtype) + + return backbone_inputs, action_inputs + + def forward(self, inputs: dict) -> BatchFeature: + """ + Forward pass through the complete model. + + Args: + inputs: Dictionary containing: + - Action inputs (state, action, embodiment_id, etc.) + + Returns: + BatchFeature containing loss and other outputs + """ + # Prepare inputs for backbone and action head + backbone_inputs, action_inputs = self.prepare_input(inputs) + backbone_outputs = self.backbone(backbone_inputs) + action_outputs = self.action_head(backbone_outputs, action_inputs) + + return action_outputs + + def get_action(self, inputs: dict, options: dict[str, Any] | None = None) -> BatchFeature: + """ + Generate actions using the complete model. + """ + # Prepare inputs for backbone and action head + backbone_inputs, action_inputs = self.prepare_input(inputs) + + # Forward through backbone + backbone_outputs = self.backbone(backbone_inputs) + action_outputs = self.action_head.get_action(backbone_outputs, action_inputs, options) + + return action_outputs + + @property + def device(self): + return next(iter(self.parameters())).device + + @property + def dtype(self): + return next(iter(self.parameters())).dtype + + +# Register the model with HuggingFace +AutoConfig.register("Gr00tN1d7", Gr00tN1d7Config) +AutoModel.register(Gr00tN1d7Config, Gr00tN1d7) diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py new file mode 100644 index 00000000000..54fbe770d0d --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import register_to_config +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps +from torch import nn + + +class AdaLayerNorm(nn.Module): + def __init__( + self, + embedding_dim: int, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + self.chunk_dim = chunk_dim + output_dim = embedding_dim * 2 + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + temb = self.linear(self.silu(temb)) + scale, shift = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale[:, None]) + shift[:, None] + return x + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + attention_bias: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: str | None = None, + num_positional_embeddings: int | None = None, + ff_inner_dim: int | None = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.norm_type = norm_type + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_position_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + self.final_dropout = nn.Dropout(dropout) if final_dropout else None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + temb: torch.LongTensor | None = None, + ) -> torch.Tensor: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, temb) + else: + norm_hidden_states = self.norm1(hidden_states) + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=(encoder_attention_mask if encoder_hidden_states is not None else attention_mask), + ) + if self.final_dropout: + attn_output = self.final_dropout(attn_output) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + return hidden_states + + +class TimestepEncoder(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + dtype = next(self.parameters()).dtype + return self.timestep_embedder(self.time_proj(timesteps).to(dtype)) + + +class DiT(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 8, + attention_head_dim: int = 64, + output_dim: int = 26, + num_layers: int = 12, + dropout: float = 0.1, + attention_bias: bool = True, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: int | None = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + max_num_positional_embeddings: int = 512, + compute_dtype=torch.float32, + final_dropout: bool = True, + positional_embeddings: str | None = "sinusoidal", + interleave_self_attention=False, + cross_attention_dim: int | None = None, + ): + super().__init__() + + self.inner_dim = num_attention_heads * attention_head_dim + self.gradient_checkpointing = False + self.timestep_encoder = TimestepEncoder(embedding_dim=self.inner_dim) + + all_blocks = [] + for idx in range(num_layers): + use_self_attn = idx % 2 == 1 and interleave_self_attention + curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None + all_blocks.append( + BasicTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + positional_embeddings=positional_embeddings, + num_positional_embeddings=max_num_positional_embeddings, + final_dropout=final_dropout, + cross_attention_dim=curr_cross_attention_dim, + ) + ) + self.transformer_blocks = nn.ModuleList(all_blocks) + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear(self.inner_dim, output_dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_all_hidden_states: bool = False, + ): + temb = self.timestep_encoder(timestep) + hidden_states = hidden_states.contiguous() + encoder_hidden_states = encoder_hidden_states.contiguous() + all_hidden_states = [hidden_states] + + for idx, block in enumerate(self.transformer_blocks): + if idx % 2 == 1 and self.config.interleave_self_attention: + hidden_states = block(hidden_states, temb=temb) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + temb=temb, + ) + all_hidden_states.append(hidden_states) + + shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + if return_all_hidden_states: + return self.proj_out_2(hidden_states), all_hidden_states + return self.proj_out_2(hidden_states) + + +class AlternateVLDiT(DiT): + def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs): + super().__init__(*args, **kwargs) + self.attend_text_every_n_blocks = attend_text_every_n_blocks + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_all_hidden_states: bool = False, + image_mask: torch.Tensor | None = None, + backbone_attention_mask: torch.Tensor | None = None, + ): + assert image_mask is not None, "Image mask is required" + assert self.config.interleave_self_attention, "Interleave self attention must be enabled" + + temb = self.timestep_encoder(timestep) + hidden_states = hidden_states.contiguous() + encoder_hidden_states = encoder_hidden_states.contiguous() + + image_attention_mask = image_mask & backbone_attention_mask + non_image_attention_mask = (~image_mask) & backbone_attention_mask + + all_hidden_states = [hidden_states] + for idx, block in enumerate(self.transformer_blocks): + if idx % 2 == 1: + hidden_states = block(hidden_states, temb=temb) + else: + curr_encoder_attention_mask = ( + non_image_attention_mask + if idx % (2 * self.attend_text_every_n_blocks) == 0 + else image_attention_mask + ) + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=curr_encoder_attention_mask, + temb=temb, + ) + all_hidden_states.append(hidden_states) + + shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + if return_all_hidden_states: + return self.proj_out_2(hidden_states), all_hidden_states + return self.proj_out_2(hidden_states) + + +class SelfAttentionTransformer(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 8, + attention_head_dim: int = 64, + output_dim: int = 26, + num_layers: int = 12, + dropout: float = 0.1, + attention_bias: bool = True, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: int | None = 1000, + upcast_attention: bool = False, + max_num_positional_embeddings: int = 512, + compute_dtype=torch.float32, + final_dropout: bool = True, + positional_embeddings: str | None = "sinusoidal", + interleave_self_attention=False, + ): + super().__init__() + + self.inner_dim = num_attention_heads * attention_head_dim + self.gradient_checkpointing = False + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + positional_embeddings=positional_embeddings, + num_positional_embeddings=max_num_positional_embeddings, + final_dropout=final_dropout, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + return_all_hidden_states: bool = False, + ): + hidden_states = hidden_states.contiguous() + all_hidden_states = [hidden_states] + for block in self.transformer_blocks: + hidden_states = block(hidden_states) + all_hidden_states.append(hidden_states) + if return_all_hidden_states: + return hidden_states, all_hidden_states + return hidden_states diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py new file mode 100644 index 00000000000..17f0ed661f7 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm_omni.diffusion.utils.flow_matching import SinusoidalPositionalEncoding, swish + + +class CategorySpecificLinear(nn.Module): + """Linear layer with category-specific weights and biases for multi-embodiment support.""" + + def __init__(self, num_categories, input_dim, hidden_dim): + super().__init__() + self.num_categories = num_categories + # For each category, we have separate weights and biases. + self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim)) + self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim)) + + def forward(self, x, cat_ids): + """ + Args: + x: [B, T, input_dim] input tensor + cat_ids: [B] category/embodiment IDs + Returns: + [B, T, hidden_dim] output tensor + """ + selected_W = self.W[cat_ids] + selected_b = self.b[cat_ids] + return torch.bmm(x, selected_W) + selected_b.unsqueeze(1) + + def expand_action_dimension(self, old_action_dim, new_action_dim, expand_input=False, expand_output=False): + """ + Safely expand action dimension with explicit targeting. + + Args: + old_action_dim: Original action dimension + new_action_dim: New (larger) action dimension + expand_input: Whether to expand input dimension (dim=1) + expand_output: Whether to expand output dimension (dim=2) + """ + if new_action_dim <= old_action_dim: + raise ValueError(f"New action dim {new_action_dim} must be larger than old action dim {old_action_dim}") + + # Expand input dimension (dim=1) only if explicitly requested AND dimensions match + if expand_input and self.W.shape[1] == old_action_dim: + repeat_times = new_action_dim // old_action_dim + remainder = new_action_dim % old_action_dim + + new_W_parts = [self.W] * repeat_times + if remainder > 0: + new_W_parts.append(self.W[:, :remainder, :]) + + new_W = torch.cat(new_W_parts, dim=1) + self.W = nn.Parameter(new_W) + + # Expand output dimension (dim=2) only if explicitly requested AND dimensions match + if expand_output and self.W.shape[2] == old_action_dim: + repeat_times = new_action_dim // old_action_dim + remainder = new_action_dim % old_action_dim + + new_W_parts = [self.W] * repeat_times + if remainder > 0: + new_W_parts.append(self.W[:, :, :remainder]) + + new_W = torch.cat(new_W_parts, dim=2) + self.W = nn.Parameter(new_W) + + # Expand bias for output dimension + if self.b.shape[1] == old_action_dim: + new_b_parts = [self.b] * repeat_times + if remainder > 0: + new_b_parts.append(self.b[:, :remainder]) + + new_b = torch.cat(new_b_parts, dim=1) + self.b = nn.Parameter(new_b) + + +class CategorySpecificMLP(nn.Module): + """Two-layer MLP with category-specific weights for multi-embodiment support.""" + + def __init__(self, num_categories, input_dim, hidden_dim, output_dim): + super().__init__() + self.num_categories = num_categories + self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim) + self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim) + + def forward(self, x, cat_ids): + """ + Args: + x: [B, T, input_dim] input tensor + cat_ids: [B] category/embodiment IDs + Returns: + [B, T, output_dim] output tensor + """ + hidden = F.relu(self.layer1(x, cat_ids)) + return self.layer2(hidden, cat_ids) + + def expand_action_dimension(self, old_action_dim, new_action_dim): + """ + Expand action dimension by copying weights from existing dimensions. + + Args: + old_action_dim: Original action dimension + new_action_dim: New (larger) action dimension + """ + # self.layer1 does not take action_dim as input, so no expansion needed + self.layer2.expand_action_dimension(old_action_dim, new_action_dim, expand_input=False, expand_output=True) + + +class MultiEmbodimentActionEncoder(nn.Module): + """Action encoder with multi-embodiment support and sinusoidal positional encoding.""" + + def __init__(self, action_dim, hidden_size, num_embodiments): + super().__init__() + self.hidden_size = hidden_size + self.num_embodiments = num_embodiments + + # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w} + self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w) + self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w) + self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w) + self.pos_encoding = SinusoidalPositionalEncoding(hidden_size) + + def forward(self, actions, timesteps, cat_ids): + """ + Args: + actions: [B, T, action_dim] action tensor + timesteps: [B,] timesteps - a single scalar per batch item + cat_ids: [B,] category/embodiment IDs + Returns: + [B, T, hidden_size] encoded action features + """ + B, T, _ = actions.shape + + # 1) Expand each batch's single scalar time 'tau' across all T steps + # so that shape => (B, T) + # e.g. if timesteps is (B,), replicate across T + if timesteps.dim() == 1 and timesteps.shape[0] == B: + # shape (B,) => (B,T) + timesteps = timesteps.unsqueeze(1).expand(-1, T) + else: + raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.") + + # 2) Standard action MLP step for shape => (B, T, w) + a_emb = self.W1(actions, cat_ids) + + # 3) Get the sinusoidal encoding (B, T, w) + tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype) + + # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish + x = torch.cat([a_emb, tau_emb], dim=-1) + x = swish(self.W2(x, cat_ids)) + + # 5) Finally W3 => (B, T, w) + x = self.W3(x, cat_ids) + return x + + def expand_action_dimension(self, old_action_dim, new_action_dim): + """ + Expand action dimension by copying weights from existing dimensions. + + Args: + old_action_dim: Original action dimension + new_action_dim: New (larger) action dimension + """ + # Only W1 takes action_dim as input, so only expand its input dimension + self.W1.expand_action_dimension(old_action_dim, new_action_dim, expand_input=True, expand_output=False) diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py new file mode 100644 index 00000000000..86b6c4b344b --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch import nn + +from vllm_omni.diffusion.utils.flow_matching import SinusoidalPositionalEncoding, swish + +__all__ = ["ActionEncoder", "SinusoidalPositionalEncoding", "swish"] + + +class ActionEncoder(nn.Module): + def __init__(self, action_dim, hidden_size): + super().__init__() + self.hidden_size = hidden_size + + # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w} + self.W1 = nn.Linear(action_dim, hidden_size) # (d -> w) + self.W2 = nn.Linear(2 * hidden_size, hidden_size) # (2w -> w) + self.W3 = nn.Linear(hidden_size, hidden_size) # (w -> w) + + self.pos_encoding = SinusoidalPositionalEncoding(hidden_size) + + def forward(self, actions, timesteps): + """ + actions: shape (B, T, action_dim) + timesteps: shape (B,) -- a single scalar per batch item + returns: shape (B, T, hidden_size) + """ + B, T, _ = actions.shape + + # 1) Expand each batch's single scalar time 'tau' across all T steps + # so that shape => (B, T) + # e.g. if timesteps is (B,), replicate across T + if timesteps.dim() == 1 and timesteps.shape[0] == B: + # shape (B,) => (B,T) + timesteps = timesteps.unsqueeze(1).expand(-1, T) + else: + raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.") + + # 2) Standard action MLP step for shape => (B, T, w) + a_emb = self.W1(actions) + + # 3) Get the sinusoidal encoding (B, T, w) + tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype) + + # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish + x = torch.cat([a_emb, tau_emb], dim=-1) + x = swish(self.W2(x)) + + # 5) Finally W3 => (B, T, w) + x = self.W3(x) + + return x diff --git a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py new file mode 100644 index 00000000000..d26ed38be86 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py @@ -0,0 +1,659 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import os +import re +import warnings +from copy import deepcopy +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import torchvision.transforms.v2 as transforms +from PIL import Image +from transformers import AutoProcessor, ProcessorMixin, Qwen3VLProcessor +from transformers.feature_extraction_utils import BatchFeature +from transformers.utils import cached_file +from vllm.logger import init_logger + +from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import ModalityConfig +from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag +from vllm_omni.diffusion.models.gr00t.dataio.state_action.state_action_processor import StateActionProcessor +from vllm_omni.diffusion.models.gr00t.dataio.utils import parse_modality_configs, to_json_serializable + + +class LetterBoxTransform: + """Pad image to square dimensions by adding black bars to the smaller side.""" + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + import math + + *leading_dims, c, h, w = img.shape + if h == w: + return img + max_dim = max(h, w) + pad_h, pad_w = max_dim - h, max_dim - w + pad_top, pad_left = pad_h // 2, pad_w // 2 + pad_bottom, pad_right = pad_h - pad_top, pad_w - pad_left + if leading_dims: + batch_size = math.prod(leading_dims) + img_r = img.reshape(batch_size, c, h, w) + padded = transforms.functional.pad(img_r, padding=[pad_left, pad_top, pad_right, pad_bottom], fill=0) + return padded.reshape(leading_dims + [c, max_dim, max_dim]) + return transforms.functional.pad(img, padding=[pad_left, pad_top, pad_right, pad_bottom], fill=0) + + +def _build_eval_image_transform( + image_target_size: list[int], + image_crop_size: list[int], +) -> transforms.Compose: + """Deterministic eval/inference image transform (letterbox → resize → centercrop → resize).""" + return transforms.Compose( + [ + transforms.ToImage(), + LetterBoxTransform(), + transforms.Resize(size=image_target_size), + transforms.CenterCrop(size=image_crop_size), + transforms.Resize(size=image_target_size), + ] + ) + + +logger = init_logger(__name__) + +# Suppress protobuf deprecation warnings +warnings.filterwarnings("ignore", category=DeprecationWarning, module="google.protobuf") + +EMBODIMENT_TAG_TO_PROJECTOR_INDEX = { + # Pretrain embodiment ids + "oxe_droid_relative_eef_relative_joint": 24, + "xdof_relative_eef_relative_joint": 27, + "xdof_relative_eef_relative_joint_subtask": 27, + "real_g1_relative_eef_relative_joints": 25, + "real_r1_pro_sharpa_relative_eef": 26, + "real_r1_pro_sharpa_relative_eef_human": 26, + "real_r1_pro_sharpa_relative_eef_maxinsights": 26, + "real_r1_pro_sharpa_relative_eef_mecka": 26, + # Posttrain embodiment ids + "unitree_g1_full_body_with_waist_height_nav_cmd": 25, + "unitree_g1_sonic": 11, + "simpler_env_google": 0, + "simpler_env_widowx": 1, + "libero_sim": 2, + "new_embodiment": 10, +} + +QWEN3_VL_2B_PROCESSOR = "Qwen/Qwen3-VL-2B-Instruct" + + +def build_processor(model_name: str, transformers_loading_kwargs: dict) -> Qwen3VLProcessor: + if model_name == "nvidia/Cosmos-Reason2-2B": + # Cosmos-Reason2-2B lacks a Qwen3VLProcessor; fall back to upstream artifacts. + logger.warning_once( + "Substituting processor from %s because %s does not ship one. " + "If you fine-tune Cosmos-Reason2-2B's tokenizer/image processor, " + "load the processor explicitly instead of relying on this fallback.", + QWEN3_VL_2B_PROCESSOR, + model_name, + ) + model_name = QWEN3_VL_2B_PROCESSOR + return Qwen3VLProcessor.from_pretrained(model_name, **transformers_loading_kwargs) + + +class Gr00tN1d7DataCollator: + def __init__( + self, + model_name: str, + model_type: str = "qwen", + transformers_loading_kwargs: dict = {}, + ): + self.processor = build_processor(model_name, transformers_loading_kwargs) + self.processor.tokenizer.padding_side = "left" + self.model_type = model_type + self.model_name = model_name + + def __call__(self, features: list[dict[str, Any]]) -> BatchFeature: + batch = {} + keys = list(set().union(*(elem.keys() for elem in features))) + + for key in keys: + values = [elem[key] for elem in features if key in elem] + if key == "vlm_content": + text_list = [] + image_inputs = [] + for v in values: + curr_text_list = [v["text"]] + + text_list += curr_text_list + curr_image_inputs = v["images"] + image_inputs += curr_image_inputs + + vlm_inputs = self.processor( + text=text_list, + images=image_inputs, + return_tensors="pt", + padding=True, + ) + for k, v in vlm_inputs.items(): + batch[k] = v + elif key in ( + "pixel_values", + "image_grid_thw", + "attention_mask", + "input_ids", + ): + raise Exception("Not implemented") + else: + batch[key] = torch.from_numpy(np.stack(values)) + return BatchFeature(data={"inputs": batch}) + + def __str__(self): + return f"Gr00tN1d7DataCollator(model_name={self.model_name}, model_type={self.model_type})" + + +class Gr00tN1d7Processor(ProcessorMixin): + data_collator_class = Gr00tN1d7DataCollator + + def __init__( + self, + modality_configs: dict[str, dict[str, ModalityConfig]], + statistics: (dict[str, dict[str, dict[str, dict[str, list[float]]]]] | None) = None, + use_percentiles: bool = False, + clip_outliers: bool = True, + image_crop_size: list[int] = None, + image_target_size: list[int] = None, + shortest_image_edge: int = 256, + crop_fraction: float = 0.95, + random_rotation_angle: int | None = None, + color_jitter_params: dict[str, float] | None = None, + formalize_language: bool = True, + model_name: str = "nvidia/Cosmos-Reason2-2B", + model_type: str = "qwen", + max_state_dim: int = 29, + max_action_dim: int = 29, + max_action_horizon: int = 50, + apply_sincos_state_encoding: bool = False, + use_relative_action: bool = False, + embodiment_id_mapping: dict[str, int] | None = None, + transformers_loading_kwargs: dict = {"trust_remote_code": True}, + exclude_state: bool = False, + # Normalization + use_mean_std: bool = False, + **kwargs, # absorb deprecated training-only keys from saved processor_config.json + ): + if kwargs: + logger.debug("Gr00tN1d7Processor: ignoring unknown keys: %s", list(kwargs)) + self.modality_configs = parse_modality_configs(modality_configs) + + # Initialize StateActionProcessor for state/action normalization + self.state_action_processor = StateActionProcessor( + modality_configs=modality_configs, + statistics=statistics, + use_percentiles=use_percentiles, + clip_outliers=clip_outliers, + apply_sincos_state_encoding=apply_sincos_state_encoding, + use_relative_action=use_relative_action, + ) + + self.use_percentiles = use_percentiles + self.use_mean_std = use_mean_std + self.clip_outliers = clip_outliers + self.apply_sincos_state_encoding = apply_sincos_state_encoding + self.use_relative_action = use_relative_action + + self.exclude_state = exclude_state + + self.formalize_language = formalize_language + self.model_name = model_name + self.model_type = model_type + + self.max_state_dim = max_state_dim + self.max_action_dim = max_action_dim + self.max_action_horizon = max_action_horizon + + self.image_crop_size = image_crop_size + self.image_target_size = image_target_size + self.random_rotation_angle = random_rotation_angle + self.color_jitter_params = color_jitter_params + self.processor = build_processor(model_name, transformers_loading_kwargs) + self.processor.tokenizer.padding_side = "left" + self.embodiment_id_mapping = embodiment_id_mapping or EMBODIMENT_TAG_TO_PROJECTOR_INDEX + for k, v in EMBODIMENT_TAG_TO_PROJECTOR_INDEX.items(): + if k not in self.embodiment_id_mapping: + self.embodiment_id_mapping[k] = v + self.shortest_image_edge = shortest_image_edge + self.crop_fraction = crop_fraction + + self.statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]] = {} + + # Eval/inference image transform + self.eval_image_transform = _build_eval_image_transform( + image_target_size, + image_crop_size, + ) + self._collator = self.data_collator_class( + model_name=model_name, + model_type=model_type, + transformers_loading_kwargs=transformers_loading_kwargs, + ) + + @property + def collator(self): + return self._collator + + def set_statistics( + self, + statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]], + override: bool = False, + ) -> None: + """Set dataset statistics for normalization.""" + for key in statistics: + if key not in self.statistics or override: + if override: + print(f"Overriding statistics for {key}") + self.statistics[key] = deepcopy(statistics[key]) + else: + print(f"Embodiment tag {key} already in statistics, skipping updating") + + self.state_action_processor.set_statistics(statistics, override=override) + + # Compute action dimensions for convenience + self.action_dim = {} + for embodiment_tag in self.state_action_processor.statistics: + self.action_dim[embodiment_tag] = self.state_action_processor.get_action_dim(embodiment_tag) + + def decode_action( + self, + action: np.ndarray, + embodiment_tag: EmbodimentTag, + state: dict[str, np.ndarray] | None = None, + ): + """Undo action normalization and convert relative actions to absolute.""" + # Split concatenated action into joint groups + out_dict = {} + start_idx = 0 + joint_groups = self.modality_configs[embodiment_tag.value]["action"].modality_keys + action_horizon = len(self.modality_configs[embodiment_tag.value]["action"].delta_indices) + for key in joint_groups: + joint_dim = self.state_action_processor.norm_params[embodiment_tag.value]["action"][key]["dim"].item() + out_dict[key] = action[..., :action_horizon, start_idx : start_idx + joint_dim] + start_idx += joint_dim + + # Use StateActionProcessor to unnormalize and convert to absolute + return self.state_action_processor.unapply_action(out_dict, embodiment_tag.value, state=state) + + def unapply( + self, + action: np.ndarray, + embodiment_tag: EmbodimentTag, + state: dict[str, np.ndarray] | None = None, + prev_action: dict[str, np.ndarray] | None = None, + ) -> dict[str, np.ndarray]: + """Undo action normalization and convert relative to absolute. + + Args: + action: Normalized action array of shape (..., action_horizon, action_dim) + embodiment_tag: Embodiment tag + state: State observations with "state." prefixed keys (for relative actions) + prev_action: Unused (kept for API compatibility) + + Returns: + Dict mapping "action." to unnormalized (absolute) action arrays. + """ + out_dict = {} + start_idx = 0 + joint_groups = self.modality_configs[embodiment_tag.value]["action"].modality_keys + action_horizon = len(self.modality_configs[embodiment_tag.value]["action"].delta_indices) + for key in joint_groups: + joint_dim = self.state_action_processor.norm_params[embodiment_tag.value]["action"][key]["dim"].item() + out_dict[key] = action[..., :action_horizon, start_idx : start_idx + joint_dim] + start_idx += joint_dim + + # Strip "state." prefix for StateActionProcessor + stripped_state = None + if state is not None: + stripped_state = {k.replace("state.", ""): v for k, v in state.items()} + + result = self.state_action_processor.unapply_action(out_dict, embodiment_tag.value, state=stripped_state) + return {f"action.{key}": value for key, value in result.items()} + + def process_observation(self, observation: dict[str, Any], embodiment_tag: EmbodimentTag): + """Process batched observation tensors for inference. + + Args: + observation: Dict with keys like "video.", "state.", "" + Video values expected as numpy arrays of shape (B, T, H, W, C). + embodiment_tag: Embodiment tag identifying the robot configuration. + + Returns: + BatchFeature with tokenized VLM inputs, state, embodiment_id, and action_mask. + """ + modality_config = self.modality_configs[embodiment_tag.value] + transformed_observation = {} + + # Normalize states + state_keys = modality_config["state"].modality_keys + state_data = {key: observation[f"state.{key}"] for key in state_keys} + exclude_state = self.exclude_state or getattr(modality_config["state"], "exclude_state", False) + if exclude_state: + normalized_states = torch.cat( + [torch.from_numpy(np.zeros_like(state_data[key])) for key in state_keys], dim=-1 + ) + else: + norm_state_dict = self.state_action_processor.apply_state( + state=state_data, embodiment_tag=embodiment_tag.value + ) + normalized_states = torch.cat([torch.from_numpy(norm_state_dict[key]) for key in state_keys], dim=-1) + + assert normalized_states.shape[1] <= self.max_state_dim, ( + f"State dimension {normalized_states.shape[1]} exceeds max_state_dim {self.max_state_dim}" + ) + padding_shape = ( + *normalized_states.shape[:-1], + self.max_state_dim - normalized_states.shape[-1], + ) + normalized_states = torch.cat([normalized_states, torch.zeros(padding_shape)], dim=-1) + transformed_observation["state"] = normalized_states + + image_keys = modality_config["video"].modality_keys + images_dict = {view: torch.from_numpy(observation[f"video.{view}"]) for view in image_keys} + images = torch.stack([images_dict[view] for view in image_keys], dim=2) # (B, T, V, H, W, C) + assert images.ndim == 6 + B, T, V, img_H, img_W, img_C = images.shape + + images_perm = images.permute(0, 1, 2, 5, 3, 4).reshape(B, T * V, img_C, img_H, img_W) # (B,T*V,C,H,W) + transformed_images = self.eval_image_transform(images_perm).numpy() + + language_key = modality_config["language"].modality_keys[0] + language = [ + re.sub(r"[^\w\s]", "", lang.lower()) if self.formalize_language else lang + for lang in observation[language_key] + ] + + texts, all_images = [], [] + for i in range(B): + vlm_inputs = self._apply_vlm_processing(transformed_images[i], language[i]) + vc = vlm_inputs["vlm_content"] + texts.append(vc["text"]) + all_images.extend(vc["images"]) + tokenized = self.processor(text=texts, images=all_images, return_tensors="pt", padding=True) + for k, v in tokenized.items(): + transformed_observation[k] = v + + embodiment_id = torch.ones(B, dtype=torch.int32) * self.embodiment_id_mapping[embodiment_tag.value] + transformed_observation["embodiment_id"] = embodiment_id + + action_config = modality_config["action"] + action_horizon = len(action_config.delta_indices) + assert action_horizon <= self.max_action_horizon, ( + f"Action horizon {action_horizon} (from delta_indices) exceeds" + f" max_action_horizon {self.max_action_horizon}. Increase model config" + f" action_horizon to >= {action_horizon}." + ) + action_mask = torch.zeros((B, self.max_action_horizon), dtype=torch.float32) + if action_horizon > 0: + action_mask[:, :action_horizon] = 1.0 + transformed_observation["action_mask"] = action_mask + + return BatchFeature(transformed_observation) + + def _apply_vlm_processing(self, images: np.ndarray, language: str) -> BatchFeature: + pil_images = [Image.fromarray(np.transpose(v, (1, 2, 0))) for v in images] + conversation = [ + { + "role": "user", + "content": [ + *[{"type": "image", "image": img} for img in pil_images], + {"type": "text", "text": language}, + ], + } + ] + + text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) + return { + "vlm_content": { + "text": text, + "images": pil_images, + "conversation": conversation, + } + } + + def __call__( + self, + messages: list[dict[str, Any]], + ): + assert len(messages) == 1 + content = messages[0]["content"] + embodiment_tag = content.embodiment + action_data = content.actions + state_data = content.states + + norm_state_dict, normalized_actions = self.state_action_processor.apply( + state=state_data, + action=action_data, + embodiment_tag=embodiment_tag.value, + ) + + if normalized_actions: + action_keys = self.modality_configs[embodiment_tag.value]["action"].modality_keys + normalized_actions = torch.cat( + [torch.from_numpy(normalized_actions[key]) for key in action_keys], + dim=-1, + ) # (t, d) + action_dim = normalized_actions.shape[1] + normalized_actions = torch.cat( + [ + normalized_actions, + torch.zeros( + normalized_actions.shape[0], + self.max_action_dim - normalized_actions.shape[1], + ), + ], + dim=-1, + ) # (t, max_action_dim) + action_horizon = normalized_actions.shape[0] + assert action_horizon <= self.max_action_horizon, ( + f"Action sequence length {action_horizon} exceeds max_action_horizon" + f" {self.max_action_horizon}. Increase model config action_horizon to" + f" >= {action_horizon}." + ) + normalized_actions = torch.cat( + [ + normalized_actions, + torch.zeros( + self.max_action_horizon - normalized_actions.shape[0], + self.max_action_dim, + ), + ], + dim=0, + ) # (max_action_horizon, max_action_dim) + action_mask = torch.ones_like(normalized_actions) + action_mask[action_horizon:] = 0 + action_mask[:, action_dim:] = 0 + else: + normalized_actions = None + action_mask = None + + state_keys = self.modality_configs[embodiment_tag.value]["state"].modality_keys + exclude_state = self.exclude_state or getattr( + self.modality_configs[embodiment_tag.value]["state"], "exclude_state", False + ) + if exclude_state: + normalized_states = torch.cat( + [torch.from_numpy(np.zeros_like(state_data[key])) for key in state_keys], dim=-1 + ) + else: + normalized_states = torch.cat([torch.from_numpy(norm_state_dict[key]) for key in state_keys], dim=-1) + normalized_states = torch.cat( + [ + normalized_states, + torch.zeros( + normalized_states.shape[0], + self.max_state_dim - normalized_states.shape[1], + ), + ], + dim=-1, + ) + + image_transform = self.eval_image_transform + image_keys = self.modality_configs[embodiment_tag.value]["video"].modality_keys + + if self.formalize_language: + language = content.text.lower() + language = re.sub(r"[^\w\s]", "", language) + else: + language = content.text + + vlm_inputs = self._get_vlm_inputs( + image_keys=image_keys, + images=content.images, + masks=content.masks, + image_transform=image_transform, + language=language, + ) + + transformed_inputs = { + "state": normalized_states.to(torch.get_default_dtype()), + } + if normalized_actions is not None: + transformed_inputs["action"] = normalized_actions.to(torch.get_default_dtype()) + # Add VLM inputs + transformed_inputs.update(vlm_inputs) + if action_mask is not None: + transformed_inputs["action_mask"] = action_mask + transformed_inputs["embodiment_id"] = self.embodiment_id_mapping[embodiment_tag.value] + return transformed_inputs + + def _get_vlm_inputs( + self, + image_keys: list[str], + images: list[Image.Image], + masks: dict[str, list[np.ndarray]] | None, + image_transform: transforms.Compose, + language: str, + ): + temporal_stacked_images = {} + + if masks is not None: + raise ValueError("Mask-based transforms are not supported at inference.") + for view in image_keys: + assert view in images, f"{view} not in {images}" + temporal_stacked_images[view] = torch.stack([image_transform(img) for img in images[view]]) # (T, C, H, W) + + for k, v in temporal_stacked_images.items(): + assert isinstance(k, str), f"{k} is not a string" + assert isinstance(v, torch.Tensor), f"{v} is not a torch tensor" + assert v.ndim == 4, f"{v} is not a 4D tensor" + assert v.dtype == torch.uint8, f"{v} is not a uint8 tensor" + assert v.shape[1] == 3, f"{v} is not a 3 channel tensor" + + stacked_images = ( + torch.stack([temporal_stacked_images[view] for view in image_keys], dim=1).flatten(0, 1).numpy() + ) # (T*V, C, H, W), processor expects numpy array + + vlm_inputs = self._apply_vlm_processing(stacked_images, language) + return vlm_inputs + + def save_pretrained(self, save_directory: str | Path) -> list[Path]: + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + main_config_file = save_directory / "processor_config.json" + statistics_file = save_directory / "statistics.json" + embodiment_id_file = save_directory / "embodiment_id.json" + + config = { + "processor_class": self.__class__.__name__, + "processor_kwargs": { + "modality_configs": to_json_serializable(self.modality_configs), + "image_crop_size": self.image_crop_size, + "image_target_size": self.image_target_size, + "random_rotation_angle": self.random_rotation_angle, + "color_jitter_params": self.color_jitter_params, + "shortest_image_edge": self.shortest_image_edge, + "crop_fraction": self.crop_fraction, + "model_name": self.model_name, + "model_type": self.model_type, + "formalize_language": self.formalize_language, + "max_state_dim": self.max_state_dim, + "max_action_dim": self.max_action_dim, + "max_action_horizon": self.max_action_horizon, + "use_percentiles": self.use_percentiles, + "use_mean_std": self.use_mean_std, + "clip_outliers": self.clip_outliers, + "apply_sincos_state_encoding": self.apply_sincos_state_encoding, + "use_relative_action": self.use_relative_action, + "exclude_state": self.exclude_state, + }, + } + with open(main_config_file, "w") as f: + json.dump(config, f, indent=2) + with open(statistics_file, "w") as f: + json.dump( + to_json_serializable(self.state_action_processor.statistics), + f, + indent=2, + ) + with open(embodiment_id_file, "w") as f: + json.dump(self.embodiment_id_mapping, f, indent=2) + return [main_config_file, statistics_file, embodiment_id_file] + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str | Path, **kwargs): + transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", {"trust_remote_code": True}) + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + config_file = pretrained_model_name_or_path / "processor_config.json" + statistics_file = pretrained_model_name_or_path / "statistics.json" + embodiment_id_file = pretrained_model_name_or_path / "embodiment_id.json" + is_local = os.path.isdir(pretrained_model_name_or_path) + if not is_local: + config_file = Path(cached_file(pretrained_model_name_or_path, "processor_config.json")) + statistics_file = Path(cached_file(pretrained_model_name_or_path, "statistics.json")) + embodiment_id_file = Path(cached_file(pretrained_model_name_or_path, "embodiment_id.json")) + + with open(config_file) as f: + config = json.load(f) + with open(statistics_file) as f: + statistics = json.load(f) + if embodiment_id_file.exists(): + with open(embodiment_id_file) as f: + embodiment_id_mapping = json.load(f) + else: + embodiment_id_mapping = None + processor_kwargs = config["processor_kwargs"] + processor_kwargs["statistics"] = statistics + processor_kwargs["embodiment_id_mapping"] = embodiment_id_mapping + + # Backfill missing fields from older checkpoints. + processor_kwargs.setdefault("model_name", "nvidia/Cosmos-Reason2-2B") + processor_kwargs.setdefault("model_type", "qwen") + processor_kwargs.setdefault("clip_outliers", True) + + # Directly override other processor kwargs + if kwargs: + # Override modality configs while keeping pretrained embodiment configs + modality_configs = kwargs.pop("modality_configs", {}) + for embodiment_tag, modality_config in modality_configs.items(): + processor_kwargs["modality_configs"][embodiment_tag] = modality_config + override_keys = [ + "random_rotation_angle", + "color_jitter_params", + "use_relative_action", + "exclude_state", + "use_mean_std", + "model_name", + "model_type", + "max_action_horizon", + "max_state_dim", + "max_action_dim", + ] + for key in override_keys: + if key in kwargs: + override = kwargs.pop(key) + if override is not None: + processor_kwargs[key] = override + return cls(**processor_kwargs, transformers_loading_kwargs=transformers_loading_kwargs) + + +AutoProcessor.register("Gr00tN1d7", Gr00tN1d7Processor) diff --git a/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py b/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py new file mode 100644 index 00000000000..63f8e4ce83e --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable, Mapping +from typing import Any + +import numpy as np +import torch +from torch import nn +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.models.gr00t.policy import Gr00tPolicy +from vllm_omni.diffusion.request import DUMMY_DIFFUSION_REQUEST_ID, OmniDiffusionRequest + +logger = init_logger(__name__) + + +def _to_float32_action_dict(actions: Mapping[str, Any]) -> dict[str, np.ndarray]: + converted = {str(key): np.asarray(value, dtype=np.float32) for key, value in actions.items()} + if not converted: + raise RuntimeError("GR00T policy returned an empty action dict.") + return converted + + +class Gr00tN1d7Pipeline(nn.Module): + """GR00T N1.7 policy pipeline backed by vLLM-Omni's local GR00T port. + + vLLM-Omni owns the serving integration: OpenPI observations arrive through + `sampling_params.extra_args["robot_obs"]`, this pipeline runs GR00T policy + inference, and actions are returned through `DiffusionOutput.multimodal_output`. + """ + + EXTRA_BODY_PARAMS = frozenset({"robot_obs"}) + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = "") -> None: + super().__init__() + model_config = od_config.model_config + self.model_path = od_config.model + self.embodiment_tag = str(model_config.get("embodiment_tag") or "OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT") + self.strict = bool(model_config.get("strict", True)) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + logger.info("Loading GR00T N1.7 policy from %s with embodiment_tag=%s", self.model_path, self.embodiment_tag) + self.policy = Gr00tPolicy( + model_path=self.model_path, + embodiment_tag=self.embodiment_tag, + device=self.device, + strict=self.strict, + ) + + def reset(self) -> dict[str, Any]: + return self.policy.reset() or {} + + @property + def weights_sources(self) -> tuple[Any, ...]: + return () + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + consumed = list(weights) + if consumed: + raise RuntimeError( + f"Gr00tN1d7Pipeline.load_weights received {len(consumed)} weight tensors; " + "weights_sources=() should prevent this. GR00T weights are loaded directly by Gr00tPolicy." + ) + return set() + + def _dummy_actions(self) -> dict[str, np.ndarray]: + embodiment_value = self.policy.embodiment_tag.value + action_config = self.policy.modality_configs["action"] + horizon = len(action_config.delta_indices) + norm_params = self.policy.processor.state_action_processor.norm_params[embodiment_value]["action"] + actions = {} + for key in action_config.modality_keys: + dim = norm_params[key]["dim"] + dim = int(dim.item() if hasattr(dim, "item") else dim) + actions[key] = np.zeros((1, horizon, dim), dtype=np.float32) + return actions + + @torch.inference_mode() + def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput: + del kwargs + extra_args = req.sampling_params.extra_args or {} + robot_obs = extra_args.get("robot_obs") + if robot_obs is None: + if req.request_id == DUMMY_DIFFUSION_REQUEST_ID: + return DiffusionOutput(multimodal_output={"actions": self._dummy_actions()}) + return DiffusionOutput(error="Gr00tN1d7Pipeline.forward expects sampling_params.extra_args['robot_obs'].") + if not isinstance(robot_obs, Mapping): + return DiffusionOutput(error=f"robot_obs must be a dict, got {type(robot_obs).__name__}.") + + if extra_args.get("reset"): + self.reset() + + policy_obs = _normalize_observation(robot_obs, language_key=self.policy.language_key) + result = self.policy.get_action(policy_obs) + actions = result[0] if isinstance(result, tuple) else result + if not isinstance(actions, Mapping): + return DiffusionOutput(error=f"GR00T policy returned {type(actions).__name__}; expected dict actions.") + return DiffusionOutput(multimodal_output={"actions": _to_float32_action_dict(actions)}) + + +def _normalize_observation(robot_obs: Mapping[str, Any], *, language_key: str) -> dict[str, Any]: + obs: dict[str, Any] = {} + if "video" in robot_obs: + obs["video"] = robot_obs["video"] + elif "images" in robot_obs: + obs["video"] = robot_obs["images"] + if "state" in robot_obs: + obs["state"] = robot_obs["state"] + if "language" in robot_obs: + obs["language"] = robot_obs["language"] + else: + prompt = robot_obs.get("prompt") + if prompt is not None: + obs["language"] = {language_key: [[str(prompt)]]} + return obs diff --git a/vllm_omni/diffusion/models/gr00t/policy.py b/vllm_omni/diffusion/models/gr00t/policy.py new file mode 100644 index 00000000000..b37c2eb46ba --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/policy.py @@ -0,0 +1,437 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from transformers import AutoModel, AutoProcessor + +from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import FINETUNE_ONLY_TAGS, POSTTRAIN_TAGS, EmbodimentTag +from vllm_omni.diffusion.models.gr00t.dataio.types import MessageType, ModalityConfig, VLAStepData +from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7 +from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor + + +def _rec_to_dtype(value: Any, dtype: torch.dtype) -> Any: + """Recursively convert floating-point tensors in nested collator output.""" + if isinstance(value, torch.Tensor): + return value.to(dtype=dtype) if torch.is_floating_point(value) else value + if isinstance(value, dict) or hasattr(value, "items"): + return {key: _rec_to_dtype(item, dtype) for key, item in value.items()} + if isinstance(value, list): + return [_rec_to_dtype(item, dtype) for item in value] + return value + + +class Gr00tPolicy: + """Core policy class for Gr00t model inference. + + This policy handles the end-to-end inference pipeline: + 1. Validates input observations + 2. Processes observations with pretrained VLA processor + 3. Runs model inference + 4. Decodes and returns actions + + The policy expects observations with specific modalities (video, state, language) + and returns actions in the format defined by the model's modality configuration. + """ + + def __init__( + self, + embodiment_tag: EmbodimentTag | str, + model_path: str, + *, + device: int | str, + strict: bool = True, + ): + """Initialize the Gr00t Policy. + + Args: + embodiment_tag: The embodiment tag defining the robot/environment type. + Accepts an EmbodimentTag enum or a string (resolved case-insensitively). + model_path: Path to the pretrained model checkpoint directory + device: Device to run the model on (e.g., 'cuda:0', 0, 'cpu') + strict: Whether to enforce strict input validation (default: True) + """ + self.strict = strict + if isinstance(embodiment_tag, str): + embodiment_tag = EmbodimentTag.resolve(embodiment_tag) + model_dir = Path(model_path) + + # Load the pretrained model and move to target device with bfloat16 precision + self.model: Gr00tN1d7 = AutoModel.from_pretrained(model_dir) + self.model.eval() + self.model.to(device=device, dtype=torch.bfloat16) + + # Training saves processor files under a "processor/" subdirectory, but + # AutoProcessor expects them at the model root. Fall back to the + # subdirectory when the root lacks a processor_config.json. + processor_dir = ( + model_dir / "processor" + if (model_dir / "processor").is_dir() and not (model_dir / "processor_config.json").exists() + else model_dir + ) + self.processor: Gr00tN1d7Processor = AutoProcessor.from_pretrained(processor_dir) + + # Store embodiment-specific configurations + self.embodiment_tag = embodiment_tag + all_modality_configs = self.processor.modality_configs + if self.embodiment_tag.value not in all_modality_configs: + # Map raw checkpoint tag values to user-friendly enum names where possible. + supported_lines = [] + for tag_value in sorted(all_modality_configs.keys()): + enum_name = EmbodimentTag.reverse_lookup(tag_value) + if enum_name != tag_value: + supported_lines.append(f" {enum_name:30s} (--embodiment-tag {enum_name})") + else: + supported_lines.append(f" {tag_value:30s} (internal, no public enum)") + supported_str = "\n".join(supported_lines) + + hint = "" + if self.embodiment_tag in POSTTRAIN_TAGS: + hint = ( + f"\n\nHint: '{self.embodiment_tag.name}' is a posttrain tag that requires " + f"a finetuned checkpoint, not the base model. " + f"See the example READMEs for how to finetune and download checkpoints." + ) + elif self.embodiment_tag in FINETUNE_ONLY_TAGS: + hint = ( + f"\n\nHint: '{self.embodiment_tag.name}' is for finetuning custom robots. " + f"Use it with launch_finetune.py, not with the base model directly." + ) + + raise ValueError( + f"Embodiment tag '{self.embodiment_tag.name}' " + f"(value='{self.embodiment_tag.value}') is not supported " + f"by this checkpoint.\n\n" + f"Supported tags in this checkpoint:\n{supported_str}" + f"{hint}" + ) + self.modality_configs = { + k: v for k, v in all_modality_configs[self.embodiment_tag.value].items() if k != "rl_info" + } + self.collate_fn = self.processor.collator + + # Extract and validate language configuration + # Embodiments may define multiple language keys (e.g. paraphrases); use only the first at inference. + language_keys = self.modality_configs["language"].modality_keys + language_delta_indices = self.modality_configs["language"].delta_indices + assert len(language_keys) >= 1, "At least one language key is required" + assert len(language_delta_indices) == 1, "Only one language delta index is supported" + self.language_key = language_keys[0] + + def _unbatch_observation(self, value: dict[str, Any]) -> list[dict[str, Any]]: + """Unbatch a batched observation into a list of single observations. + + Args: + value: Batched observation with shape (B, ...) for each modality + + Returns: + List of B observations, each with the batch dimension removed + """ + unbatched_obs = [] + # Infer batch size from the first video key + batch_size = value["video"][list(value["video"].keys())[0]].shape[0] + + # Split each modality along the batch dimension + for i in range(batch_size): + unbatched_value = { + "video": {k: v[i] for k, v in value["video"].items()}, + "state": {k: v[i] for k, v in value["state"].items()}, + "language": {k: v[i] for k, v in value["language"].items()}, + } + unbatched_obs.append(unbatched_value) + return unbatched_obs + + def _to_vla_step_data(self, observation: dict[str, Any]) -> VLAStepData: + """Convert a single observation into a VLAStepData object for processing. + + Args: + observation: Single observation dict with video, state, and language + + Returns: + VLAStepData object ready for processor input + """ + return VLAStepData( + images=observation["video"], + states=observation["state"], + actions={}, # No ground truth actions during inference + text=observation["language"][self.language_key][0], + embodiment=self.embodiment_tag, + ) + + def check_observation(self, observation: dict[str, Any]) -> None: + """Validate that the observation has the correct structure and types. + + This method ensures that all required modalities are present and that their + data types, shapes, and dimensions match the model's expectations. + + Expected observation structure: + - video: dict[str, np.ndarray[np.uint8, (B, T, H, W, C)]] + - B: batch size + - T: temporal horizon (number of frames) + - H, W: image height and width + - C: number of channels (must be 3 for RGB) + - state: dict[str, np.ndarray[np.float32, (B, T, D)]] + - B: batch size + - T: temporal horizon (number of state observations) + - D: state dimension + - language: dict[str, list[list[str]]] + - Shape: (B, T) where each element is a string + - T: temporal horizon (typically 1 for language) + + Args: + observation: Dictionary containing video, state, and language modalities + + Raises: + AssertionError: If any validation check fails + """ + # Check that observation contains all required top-level modality keys + for modality in ["video", "state", "language"]: + assert modality in observation, f"Observation must contain a '{modality}' key" + assert isinstance(observation[modality], dict), ( + f"Observation '{modality}' must be a dictionary. " + f"Got {type(observation[modality])}: {observation[modality]}" + ) + + # Track batch size across modalities to ensure consistency + bs = -1 + + for video_key in self.modality_configs["video"].modality_keys: + assert video_key in observation["video"], f"Video key '{video_key}' must be in observation" + + # Set or verify batch size consistency across all video keys + if bs == -1: + bs = len(observation["video"][video_key]) + else: + assert len(observation["video"][video_key]) == bs, ( + f"Video key '{video_key}' must have batch size {bs}. Got {len(observation['video'][video_key])}" + ) + + batched_video = observation["video"][video_key] + + # Verify data type is numpy array + assert isinstance(batched_video, np.ndarray), ( + f"Video key '{video_key}' must be a numpy array. Got {type(batched_video)}" + ) + + # Verify dtype is uint8 (standard for image data, range 0-255) + assert batched_video.dtype == np.uint8, ( + f"Video key '{video_key}' must be a numpy array of type np.uint8. Got {batched_video.dtype}" + ) + + # Verify shape has 5 dimensions: (B, T, H, W, C) + assert batched_video.ndim == 5, ( + f"Video key '{video_key}' must be a numpy array of shape (B, T, H, W, C), got {batched_video.shape}" + ) + + # Verify temporal dimension matches the expected horizon from config + assert batched_video.shape[1] == len(self.modality_configs["video"].delta_indices), ( + f"Video key '{video_key}'s horizon must be " + f"{len(self.modality_configs['video'].delta_indices)}. Got {batched_video.shape[1]}" + ) + + # Verify channel dimension is 3 (RGB images) + assert batched_video.shape[-1] == 3, ( + f"Video key '{video_key}'s channel 'C' must be 3. Got {batched_video.shape[-1]}" + ) + + for state_key in self.modality_configs["state"].modality_keys: + # Check that the expected state key exists in the observation + # Must happen before indexing; see video validation above. + assert state_key in observation["state"], f"State key '{state_key}' must be in observation" + + # Set or verify batch size consistency across all state keys + if bs == -1: + bs = len(observation["state"][state_key]) + else: + assert len(observation["state"][state_key]) == bs, ( + f"State key '{state_key}' must have batch size {bs}. Got {len(observation['state'][state_key])}" + ) + + batched_state = observation["state"][state_key] + + # Verify data type is numpy array + assert isinstance(batched_state, np.ndarray), ( + f"State key '{state_key}' must be a numpy array. Got {type(batched_state)}" + ) + + # Verify dtype is float32 (standard for continuous state values) + assert batched_state.dtype == np.float32, ( + f"State key '{state_key}' must be a numpy array of type np.float32. Got {batched_state.dtype}" + ) + + # Verify shape has 3 dimensions: (B, T, D) + assert batched_state.ndim == 3, ( + f"State key '{state_key}' must be a numpy array of shape (B, T, D), got {batched_state.shape}" + ) + + # Verify temporal dimension matches the expected horizon from config + assert batched_state.shape[1] == len(self.modality_configs["state"].delta_indices), ( + f"State key '{state_key}'s horizon must be " + f"{len(self.modality_configs['state'].delta_indices)}. Got {batched_state.shape[1]}" + ) + + for language_key in self.modality_configs["language"].modality_keys: + # Check that the expected language key exists in the observation + # Must happen before indexing; see video validation above. + assert language_key in observation["language"], f"Language key '{language_key}' must be in observation" + + # Set or verify batch size consistency (language uses len instead of .shape) + if bs == -1: + bs = len(observation["language"][language_key]) + else: + assert len(observation["language"][language_key]) == bs, ( + f"Language key '{language_key}' must have batch size {bs}. " + f"Got {len(observation['language'][language_key])}" + ) + + batched_language: list[list[str]] = observation["language"][language_key] + + # Verify outer structure is a list (batch dimension) + assert isinstance(batched_language, list), ( + f"Language key '{language_key}' must be a list. Got {type(batched_language)}" + ) + + # Validate each batch item + for batch_item in batched_language: + # Verify temporal dimension matches expected horizon + assert len(batch_item) == len(self.modality_configs["language"].delta_indices), ( + f"Language key '{language_key}'s horizon must be " + f"{len(self.modality_configs['language'].delta_indices)}. Got {len(batched_language)}" + ) + + # Verify inner structure is also a list (temporal dimension) + assert isinstance(batch_item, list), f"Language batch item must be a list. Got {type(batch_item)}" + + # Current implementation expects exactly one language instruction per timestep + assert len(batch_item) == 1, f"Language batch item must have exactly one item. Got {len(batch_item)}" + + # Verify the instruction itself is a string + assert isinstance(batch_item[0], str), ( + f"Language batch item must be a string. Got {type(batch_item[0])}" + ) + + def _get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Internal method to compute actions from observations. + + Pipeline: + 1. Unbatch observations into individual samples + 2. Convert each to VLAStepData and process + 3. Collate into model input batch + 4. Run model inference + 5. Decode and unnormalize actions + + Args: + observation: Batched observation dictionary + options: Optional parameters (currently unused) + + Returns: + Tuple of (actions_dict, info_dict) + """ + # Step 1: Split batched observation into individual observations + unbatched_observations = self._unbatch_observation(observation) + processed_inputs = [] + + # Step 2: Process each observation through the VLA processor + states = [] + for obs in unbatched_observations: + vla_step_data = self._to_vla_step_data(obs) + states.append(vla_step_data.states) # dict[str, np.ndarray[np.float32, (T, D)]] + messages = [{"type": MessageType.EPISODE_STEP.value, "content": vla_step_data}] + processed_inputs.append(self.processor(messages)) + + # Step 3: Collate processed inputs into a single batch for model + collated_inputs = self.collate_fn(processed_inputs) + collated_inputs = _rec_to_dtype(collated_inputs, dtype=torch.bfloat16) + + # Step 4: Run model inference to predict actions + with torch.inference_mode(): + model_pred = self.model.get_action(**collated_inputs) + normalized_action = model_pred["action_pred"].float() + + # Step 5: Decode actions from normalized space back to physical units + batched_states = {} + for k in self.modality_configs["state"].modality_keys: + batched_states[k] = np.stack([s[k] for s in states], axis=0) # (B, T, D) + unnormalized_action = self.processor.decode_action( + normalized_action.cpu().numpy(), self.embodiment_tag, batched_states + ) + + # Cast all actions to float32 for consistency + casted_action = {key: value.astype(np.float32) for key, value in unnormalized_action.items()} + return casted_action, {} + + def check_action(self, action: dict[str, Any]) -> None: + """Validate that the action has the correct structure and types. + + This method ensures that all required action keys are present and that their + data types, shapes, and dimensions match the model's action space. + + Expected action structure: + - action: dict[str, np.ndarray[np.float32, (B, T, D)]] + - B: batch size + - T: action horizon (number of future action steps) + - D: action dimension (e.g., joint positions, velocities, gripper state) + + Args: + action: Dictionary containing action arrays for each action key + + Raises: + AssertionError: If any validation check fails + """ + # Validate each action key defined in the modality config + for action_key in self.modality_configs["action"].modality_keys: + # Check that the expected action key exists + assert action_key in action, f"Action key '{action_key}' must be in action" + + action_arr = action[action_key] + + # Verify data type is numpy array + assert isinstance(action_arr, np.ndarray), ( + f"Action key '{action_key}' must be a numpy array. Got {type(action_arr)}" + ) + + # Verify dtype is float32 (standard for continuous actions) + assert action_arr.dtype == np.float32, ( + f"Action key '{action_key}' must be a numpy array of type np.float32. Got {action_arr.dtype}" + ) + + # Verify shape has 3 dimensions: (B, T, D) + assert action_arr.ndim == 3, ( + f"Action key '{action_key}' must be a numpy array of shape (B, T, D), got {action_arr.shape}" + ) + + # Verify action horizon matches the expected temporal dimension from config + assert action_arr.shape[1] == len(self.modality_configs["action"].delta_indices), ( + f"Action key '{action_key}'s horizon must be " + f"{len(self.modality_configs['action'].delta_indices)}. Got {action_arr.shape[1]}" + ) + + def get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + if self.strict: + self.check_observation(observation) + action, info = self._get_action(observation, options) + if self.strict: + self.check_action(action) + return action, info + + def get_modality_config(self) -> dict[str, ModalityConfig]: + return self.modality_configs + + def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: + """Reset the policy to its initial state. + + Args: + options: Dictionary containing the options for the reset + + Returns: + Dictionary containing the info after resetting the policy + """ + return {} diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index a381640b507..bbae141116b 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -151,6 +151,11 @@ "pipeline_internvla_a1", "InternVLAA1Pipeline", ), + "Gr00tN1d7Pipeline": ( + "gr00t", + "pipeline_gr00t", + "Gr00tN1d7Pipeline", + ), "LongCatImageEditPipeline": ( "longcat_image", "pipeline_longcat_image_edit", diff --git a/vllm_omni/diffusion/utils/flow_matching.py b/vllm_omni/diffusion/utils/flow_matching.py new file mode 100644 index 00000000000..e81d405ec23 --- /dev/null +++ b/vllm_omni/diffusion/utils/flow_matching.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import torch +from torch import nn + + +def swish(x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, embedding_dim: int) -> None: + super().__init__() + self.embedding_dim = embedding_dim + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + timesteps = timesteps.float() + B, T = timesteps.shape + device = timesteps.device + + half_dim = self.embedding_dim // 2 + exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (math.log(10000.0) / half_dim) + freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim) + enc = torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1) # (B, T, embedding_dim) + return enc diff --git a/vllm_omni/model_executor/models/gr00t/__init__.py b/vllm_omni/model_executor/models/gr00t/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/model_executor/models/gr00t/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/model_executor/models/gr00t/pipeline.py b/vllm_omni/model_executor/models/gr00t/pipeline.py new file mode 100644 index 00000000000..38a76f29769 --- /dev/null +++ b/vllm_omni/model_executor/models/gr00t/pipeline.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""GR00T N1.7 single-stage policy topology.""" + +from vllm_omni.config.stage_config import ( + PipelineConfig, + StageExecutionType, + StagePipelineConfig, +) + +GR00T_N1D7_PIPELINE = PipelineConfig( + model_type="Gr00tN1d7", + model_arch="Gr00tN1d7Pipeline", + hf_architectures=("Gr00tN1d7",), + stages=( + StagePipelineConfig( + stage_id=0, + model_stage="diffusion", + execution_type=StageExecutionType.DIFFUSION, + input_sources=(), + final_output=True, + final_output_type="actions", + model_arch="Gr00tN1d7Pipeline", + ), + ), +)