From e2e2195762ec7979d42a6256c926a06a72a27f03 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Thu, 21 May 2026 20:50:21 +0700 Subject: [PATCH 1/8] feat(gr00t): add GR00T-N1.7 pipeline with OpenPI serving Signed-off-by: Zhengyuan Su --- docs/.nav.yml | 1 + docs/models/supported_models.md | 1 + .../examples/online_serving/gr00t.md | 26 + examples/online_serving/gr00t/README.md | 33 + requirements/common.txt | 1 + tests/diffusion/models/gr00t/__init__.py | 0 tests/diffusion/models/gr00t/test_pipeline.py | 142 ++++ .../test_diffusion_engine_actions.py | 71 ++ tests/e2e/online_serving/test_gr00t_openpi.py | 44 + .../openai_api/test_openpi_serving.py | 10 + tests/gr00t/__init__.py | 2 + tests/gr00t/openpi_client_helper.py | 183 +++++ vllm_omni/config/pipeline_registry.py | 6 +- vllm_omni/deploy/Gr00tN1d7.yaml | 40 + vllm_omni/diffusion/data.py | 17 +- vllm_omni/diffusion/diffusion_engine.py | 23 +- vllm_omni/diffusion/models/gr00t/__init__.py | 9 + .../models/gr00t/configs/__init__.py | 2 + .../gr00t/configs/embodiment/__init__.py | 2 + .../configs/embodiment/embodiment_configs.py | 251 ++++++ .../models/gr00t/configs/model/__init__.py | 8 + .../models/gr00t/configs/model/gr00t_n1d7.py | 184 +++++ .../diffusion/models/gr00t/dataio/__init__.py | 2 + .../models/gr00t/dataio/collator/__init__.py | 6 + .../models/gr00t/dataio/collator/collators.py | 27 + .../models/gr00t/dataio/embodiment_tags.py | 207 +++++ .../models/gr00t/dataio/interfaces.py | 143 ++++ .../gr00t/dataio/state_action/__init__.py | 2 + .../dataio/state_action/action_chunking.py | 666 +++++++++++++++ .../models/gr00t/dataio/state_action/pose.py | 721 ++++++++++++++++ .../state_action/state_action_processor.py | 672 +++++++++++++++ .../diffusion/models/gr00t/dataio/types.py | 124 +++ .../diffusion/models/gr00t/dataio/utils.py | 303 +++++++ .../models/gr00t/modeling/__init__.py | 7 + .../models/gr00t/modeling/gr00t_n1d7.py | 776 ++++++++++++++++++ .../gr00t/modeling/image_augmentations.py | 564 +++++++++++++ 
.../models/gr00t/modeling/modules/__init__.py | 2 + .../models/gr00t/modeling/modules/dit.py | 478 +++++++++++ .../modules/embodiment_conditioned_mlp.py | 228 +++++ .../modeling/modules/flowmatching_modules.py | 111 +++ .../gr00t/modeling/processing_gr00t_n1d7.py | 762 +++++++++++++++++ .../diffusion/models/gr00t/pipeline_gr00t.py | 137 ++++ vllm_omni/diffusion/models/gr00t/policy.py | 717 ++++++++++++++++ .../diffusion/models/gr00t/policy_base.py | 132 +++ .../models/internvla_a1/adapter_qwen3_vl.py | 14 +- vllm_omni/diffusion/registry.py | 6 + .../model_executor/models/gr00t/__init__.py | 2 + .../model_executor/models/gr00t/pipeline.py | 26 + 48 files changed, 7868 insertions(+), 23 deletions(-) create mode 100644 docs/user_guide/examples/online_serving/gr00t.md create mode 100644 examples/online_serving/gr00t/README.md create mode 100644 tests/diffusion/models/gr00t/__init__.py create mode 100644 tests/diffusion/models/gr00t/test_pipeline.py create mode 100644 tests/diffusion/test_diffusion_engine_actions.py create mode 100644 tests/e2e/online_serving/test_gr00t_openpi.py create mode 100644 tests/gr00t/__init__.py create mode 100644 tests/gr00t/openpi_client_helper.py create mode 100644 vllm_omni/deploy/Gr00tN1d7.yaml create mode 100644 vllm_omni/diffusion/models/gr00t/__init__.py create mode 100644 vllm_omni/diffusion/models/gr00t/configs/__init__.py create mode 100644 vllm_omni/diffusion/models/gr00t/configs/embodiment/__init__.py create mode 100644 vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py create mode 100644 vllm_omni/diffusion/models/gr00t/configs/model/__init__.py create mode 100644 vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/__init__.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/collator/__init__.py create mode 100755 vllm_omni/diffusion/models/gr00t/dataio/collator/collators.py create mode 100755 
vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/interfaces.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/state_action/__init__.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/types.py create mode 100644 vllm_omni/diffusion/models/gr00t/dataio/utils.py create mode 100644 vllm_omni/diffusion/models/gr00t/modeling/__init__.py create mode 100644 vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py create mode 100755 vllm_omni/diffusion/models/gr00t/modeling/image_augmentations.py create mode 100644 vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py create mode 100755 vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py create mode 100644 vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py create mode 100644 vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py create mode 100644 vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py create mode 100644 vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py create mode 100644 vllm_omni/diffusion/models/gr00t/policy.py create mode 100644 vllm_omni/diffusion/models/gr00t/policy_base.py create mode 100644 vllm_omni/model_executor/models/gr00t/__init__.py create mode 100644 vllm_omni/model_executor/models/gr00t/pipeline.py diff --git a/docs/.nav.yml b/docs/.nav.yml index 562a49e84f8..dc28cb1e49a 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -37,6 +37,7 @@ nav: - vLLM-Omni Helm Chart: user_guide/examples/online_serving/chart-helm.md - Diffusers Backend Adapter: user_guide/examples/online_serving/diffusers_pipeline_adapter.md - GLM-Image Online Serving: 
user_guide/examples/online_serving/glm_image.md + - GR00T OpenPI Serving: user_guide/examples/online_serving/gr00t.md - Image-To-Image: user_guide/examples/online_serving/image_to_image.md - Image-To-Video: user_guide/examples/online_serving/image_to_video.md - Online serving Example of vLLM-Omni for MiMo-Audio: user_guide/examples/online_serving/mimo_audio.md diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index f58123317de..3b1e60b47ae 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -22,6 +22,7 @@ th { | `MingFlashOmniForConditionalGeneration` + `MingImagePipeline` | Ming-flash-omni-2.0 (omni-speech + imagegen1) | `Jonathan1909/Ming-flash-omni-2.0` | ✅︎ | | | | | `BagelForConditionalGeneration` | BAGEL (DiT-only) | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ | | ✅︎ | | `InternVLAA1Pipeline` | InternVLA-A1 | `InternRobotics/InternVLA-A1-3B` | ✅︎ | ✅︎ | | | +| `Gr00tN1d7Pipeline` | GR00T N1.7 | `nvidia/GR00T-N1.7-3B` | ✅︎ | | | | | `HunyuanImage3ForCausalMM` | HunyuanImage3.0 (DiT-only) | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `QwenImagePipeline` | Qwen-Image-2512 | `Qwen/Qwen-Image-2512` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | diff --git a/docs/user_guide/examples/online_serving/gr00t.md b/docs/user_guide/examples/online_serving/gr00t.md new file mode 100644 index 00000000000..b42de45b857 --- /dev/null +++ b/docs/user_guide/examples/online_serving/gr00t.md @@ -0,0 +1,26 @@ +# GR00T OpenPI Serving + +Source . + +GR00T N1.7 is served through `/v1/realtime/robot/openpi`. The endpoint uses the OpenPI msgpack-numpy websocket protocol and returns GR00T actions as `dict[str, np.ndarray]`. + +## Prerequisites + +Install `openpi-client` in the serving environment. The OpenPI endpoint uses `openpi_client.msgpack_numpy` to pack and unpack websocket payloads. 
+ +## Start the server + +```bash +uv run --no-sync --with openpi-client vllm serve nvidia/GR00T-N1.7-3B \ + --omni \ + --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml +``` + +The deploy config is `vllm_omni/deploy/Gr00tN1d7.yaml`. It registers `Gr00tN1d7Pipeline` and exposes `policy_server_config` for the OpenPI handshake. + +## Action output + +Unlike single-stream policies that return one ndarray, GR00T returns a per-action-key dictionary. vLLM-Omni preserves that dictionary under `multimodal_output["actions"]`, and the OpenPI endpoint sends it as the websocket success payload. + +??? abstract "Example README" + --8<-- "examples/online_serving/gr00t/README.md" diff --git a/examples/online_serving/gr00t/README.md b/examples/online_serving/gr00t/README.md new file mode 100644 index 00000000000..38db1a511a1 --- /dev/null +++ b/examples/online_serving/gr00t/README.md @@ -0,0 +1,33 @@ +# GR00T OpenPI Serving + +This example serves NVIDIA Isaac GR00T N1.7 through the OpenPI-compatible robot websocket endpoint. + +## Requirements + +- Install the OpenPI client dependency used by the websocket protocol: + +```bash +pip install openpi-client +``` + +## Start the server + +From the repository root: + +```bash +vllm serve nvidia/GR00T-N1.7-3B \ + --omni \ + --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml +``` + +The websocket endpoint is: + +```text +ws://127.0.0.1:8000/v1/realtime/robot/openpi +``` + +## Request and response + +The OpenPI serving layer forwards raw robot observations through `sampling_params.extra_args["robot_obs"]`. The GR00T pipeline converts the observation to the local GR00T policy input shape and returns `multimodal_output["actions"]` as `dict[str, np.ndarray]`. + +The server handshake advertises the configured embodiment and action schema through `policy_server_config` in `vllm_omni/deploy/Gr00tN1d7.yaml`. 
diff --git a/requirements/common.txt b/requirements/common.txt index e649450afdb..1af2eaef55a 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -10,6 +10,7 @@ omegaconf>=2.3.0 diffusers==0.38.0 safetensors>=0.8.0rc0 accelerate==1.12.0 +albumentations==1.4.18 soundfile>=0.13.1 cache-dit==1.3.0 tqdm>=4.66.0 diff --git a/tests/diffusion/models/gr00t/__init__.py b/tests/diffusion/models/gr00t/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/models/gr00t/test_pipeline.py b/tests/diffusion/models/gr00t/test_pipeline.py new file mode 100644 index 00000000000..dd75773ec3f --- /dev/null +++ b/tests/diffusion/models/gr00t/test_pipeline.py @@ -0,0 +1,142 @@ +from types import SimpleNamespace + +import numpy as np +import pytest + +from vllm_omni.diffusion.models.gr00t import pipeline_gr00t +from vllm_omni.diffusion.models.gr00t.pipeline_gr00t import Gr00tN1d7Pipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class FakeGr00tPolicy: + instances = [] + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.reset_calls = 0 + self.seen_obs = None + self.embodiment_tag = SimpleNamespace(value="fake_embodiment") + self.language_key = "annotation.language.language_instruction" + self.modality_configs = { + "action": SimpleNamespace( + delta_indices=[0, 1], + modality_keys=["arm", "gripper"], + ) + } + self.processor = SimpleNamespace( + state_action_processor=SimpleNamespace( + norm_params={ + "fake_embodiment": { + "action": { + "arm": {"dim": np.array(2)}, + "gripper": {"dim": np.array(1)}, + } + } + } + ) + ) + FakeGr00tPolicy.instances.append(self) + + def get_action(self, obs): + self.seen_obs = obs + return { + "arm": np.array([[[1.0, 2.0]]], dtype=np.float64), + "gripper": [[[3.0]]], + }, {"latency_ms": 1.0} + + def reset(self): + self.reset_calls += 1 + return 
{"reset": True} + + +@pytest.fixture(autouse=True) +def fake_gr00t_policy(monkeypatch): + FakeGr00tPolicy.instances.clear() + monkeypatch.setattr(pipeline_gr00t, "Gr00tPolicy", FakeGr00tPolicy) + + +def _pipeline(): + od_config = SimpleNamespace( + model="nvidia/GR00T-N1.7-3B", + model_config={ + "embodiment_tag": "LIBERO_PANDA", + "strict": False, + }, + custom_pipeline_args={}, + ) + return Gr00tN1d7Pipeline(od_config=od_config) + + +def test_pipeline_initializes_local_policy(): + pipeline = _pipeline() + + policy = FakeGr00tPolicy.instances[0] + assert policy.kwargs["model_path"] == "nvidia/GR00T-N1.7-3B" + assert policy.kwargs["embodiment_tag"] == "LIBERO_PANDA" + assert policy.kwargs["strict"] is False + assert pipeline.weights_sources == () + assert pipeline.load_weights(iter(())) == set() + + +def test_forward_returns_dict_actions_in_multimodal_output(): + pipeline = _pipeline() + req = OmniDiffusionRequest( + prompts=["pick"], + request_ids=["req"], + sampling_params=OmniDiffusionSamplingParams( + extra_args={ + "robot_obs": { + "images": {"cam": np.zeros((1, 1, 8, 8, 3), dtype=np.uint8)}, + "state": {"joint": np.zeros((1, 1, 2), dtype=np.float32)}, + "prompt": "pick the cube", + "session_id": "session-a", + }, + "reset": True, + } + ), + ) + + output = pipeline.forward(req) + + assert output.error is None + assert output.output is None + actions = output.multimodal_output["actions"] + assert set(actions) == {"arm", "gripper"} + assert actions["arm"].dtype == np.float32 + np.testing.assert_allclose(actions["arm"], np.array([[[1.0, 2.0]]], dtype=np.float32)) + policy = FakeGr00tPolicy.instances[0] + assert "video" in policy.seen_obs + assert policy.seen_obs["language"] == {"annotation.language.language_instruction": [["pick the cube"]]} + assert "images" not in policy.seen_obs + assert "prompt" not in policy.seen_obs + assert "session_id" not in policy.seen_obs + assert policy.reset_calls == 1 + + +def test_dummy_warmup_returns_shape_correct_zero_actions(): 
+ pipeline = _pipeline() + req = OmniDiffusionRequest( + prompts=["dummy run"], + request_ids=["dummy_req_id"], + sampling_params=OmniDiffusionSamplingParams(num_inference_steps=1), + ) + + output = pipeline.forward(req) + + assert output.error is None + actions = output.multimodal_output["actions"] + assert set(actions) == {"arm", "gripper"} + assert actions["arm"].shape == (1, 2, 2) + assert actions["gripper"].shape == (1, 2, 1) + assert not actions["arm"].any() + assert FakeGr00tPolicy.instances[0].seen_obs is None + + +def test_reset_delegates_to_policy(): + pipeline = _pipeline() + + assert pipeline.reset() == {"reset": True} + assert FakeGr00tPolicy.instances[0].reset_calls == 1 diff --git a/tests/diffusion/test_diffusion_engine_actions.py b/tests/diffusion/test_diffusion_engine_actions.py new file mode 100644 index 00000000000..8408c417649 --- /dev/null +++ b/tests/diffusion/test_diffusion_engine_actions.py @@ -0,0 +1,71 @@ +import asyncio +import importlib.util +import types + +import numpy as np +import pytest + +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +HAS_VLLM_IR = importlib.util.find_spec("vllm.ir") is not None +if HAS_VLLM_IR: + from vllm_omni.diffusion.diffusion_engine import DiffusionEngine + +pytestmark = [ + pytest.mark.core_model, + pytest.mark.cpu, + pytest.mark.skipif(not HAS_VLLM_IR, reason="Installed vLLM does not provide vllm.ir"), +] + + +def _engine_with_output(output: DiffusionOutput): + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.od_config = types.SimpleNamespace(enable_cpu_offload=False, model_class_name="Gr00tN1d7Pipeline") + engine.pre_process_func = None + engine.post_process_func = None + engine._post_process_accepts_sampling_params = False + + async def check_loop(): + return None + + async def run_request(request): + return output + + engine._check_and_start_background_loop = 
check_loop + engine.async_add_req_and_wait_for_response = run_request + return engine + + +def test_diffusion_engine_surfaces_action_multimodal_output(): + actions = {"arm": np.array([[[1.0, 2.0]]], dtype=np.float32)} + engine = _engine_with_output(DiffusionOutput(multimodal_output={"actions": actions})) + req = OmniDiffusionRequest( + prompts=["pick"], + request_ids=["req"], + sampling_params=OmniDiffusionSamplingParams(), + ) + + outputs = asyncio.run(engine.step(req)) + + assert len(outputs) == 1 + assert outputs[0].images == [] + assert outputs[0].final_output_type == "actions" + assert outputs[0].multimodal_output["actions"] is actions + + +def test_diffusion_engine_surfaces_actions_from_output_dict(): + actions = {"arm": np.array([[[1.0, 2.0]]], dtype=np.float32)} + engine = _engine_with_output(DiffusionOutput(output={"actions": actions})) + req = OmniDiffusionRequest( + prompts=["pick"], + request_ids=["req"], + sampling_params=OmniDiffusionSamplingParams(), + ) + + outputs = asyncio.run(engine.step(req)) + + assert outputs[0].images == [] + assert outputs[0].final_output_type == "actions" + assert outputs[0].multimodal_output["actions"] is actions diff --git a/tests/e2e/online_serving/test_gr00t_openpi.py b/tests/e2e/online_serving/test_gr00t_openpi.py new file mode 100644 index 00000000000..c40137232b9 --- /dev/null +++ b/tests/e2e/online_serving/test_gr00t_openpi.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""End-to-end online serving test for GR00T N1.7 through the OpenPI robot endpoint.""" + +import os + +import pytest + +from tests.gr00t import openpi_client_helper as openpi_client +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServerParams +from tests.helpers.stage_config import get_deploy_config_path + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODEL = "nvidia/GR00T-N1.7-3B" +openpi_client.require_dependencies() + 
+test_params = [ + pytest.param( + OmniServerParams( + model=MODEL, + stage_config_path=get_deploy_config_path("Gr00tN1d7.yaml"), + server_args=["--disable-log-stats"], + env_dict={"VLLM_DISABLE_COMPILE_CACHE": "1"}, + init_timeout=1200, + stage_init_timeout=900, + ), + id="gr00t-n1d7-openpi", + ) +] + + +@pytest.mark.advanced_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100"}, num_cards=1) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_gr00t_n1d7_openpi_online(omni_server) -> None: + result = openpi_client.run_policy_session( + host=omni_server.host, + port=omni_server.port, + session_id="gr00t-online-e2e", + ) + openpi_client.validate_session_result(result) diff --git a/tests/entrypoints/openai_api/test_openpi_serving.py b/tests/entrypoints/openai_api/test_openpi_serving.py index 9eb1d5bfe0b..597ba9f3028 100644 --- a/tests/entrypoints/openai_api/test_openpi_serving.py +++ b/tests/entrypoints/openai_api/test_openpi_serving.py @@ -2,6 +2,7 @@ import json import threading from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from types import SimpleNamespace import numpy as np @@ -24,6 +25,8 @@ "action_space": "joint_position", } +GR00T_DEPLOY_CONFIG = Path(__file__).resolve().parents[3] / "vllm_omni" / "deploy" / "Gr00tN1d7.yaml" + def _json_default(obj): if isinstance(obj, np.ndarray): @@ -101,6 +104,13 @@ async def _generate(): return _generate() +def test_gr00t_deploy_reports_droid_image_resolution(): + config = OmegaConf.load(GR00T_DEPLOY_CONFIG) + policy_config = config.stages[0].model_config.policy_server_config + + assert list(policy_config.image_resolution) == [180, 320] + + def test_policy_server_config_reads_diffusion_model_config(): policy_config = { "image_resolution": [64, 64], diff --git a/tests/gr00t/__init__.py b/tests/gr00t/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/tests/gr00t/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/gr00t/openpi_client_helper.py b/tests/gr00t/openpi_client_helper.py new file mode 100644 index 00000000000..e5943ca3307 --- /dev/null +++ b/tests/gr00t/openpi_client_helper.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any + +import numpy as np + +try: + import websockets.sync.client as websockets_client +except ImportError: # pragma: no cover - optional e2e dependency + websockets_client = None + +try: + from openpi_client import msgpack_numpy +except ImportError: # pragma: no cover - optional e2e dependency + msgpack_numpy = None + +PING_INTERVAL_SECS = 300 +PING_TIMEOUT_SECS = 3600 +DEFAULT_HOST = "127.0.0.1" +DEFAULT_PORT = 8000 +DEFAULT_PATH = "/v1/realtime/robot/openpi" +DEFAULT_SESSION_ID = "gr00t-smoke" +ACTION_KEYS = {"eef_9d", "gripper_position", "joint_position"} +LANGUAGE_KEY = "annotation.language.language_instruction" + + +def _identity_eef_9d_state() -> np.ndarray: + state = np.zeros((1, 1, 9), dtype=np.float32) + state[..., 3:] = np.array([1, 0, 0, 0, 1, 0], dtype=np.float32) + return state + + +def require_dependencies() -> None: + missing = [] + if websockets_client is None: + missing.append("websockets") + if msgpack_numpy is None: + missing.append("openpi-client") + if missing: + raise ModuleNotFoundError(f"GR00T OpenPI test dependencies are missing: {', '.join(missing)}") + + +@dataclass(frozen=True) +class Gr00tServerMetadata: + action_horizon: int + action_keys: set[str] + embodiment_tag: str + needs_session_id: bool + + @classmethod + def from_dict(cls, payload: dict[str, Any]): + required_keys = ("action_horizon", "action_keys", "embodiment_tag", "needs_session_id") + missing_keys = [key for key in required_keys if key not in payload] + if missing_keys: + raise ValueError(f"Missing GR00T metadata keys: 
{missing_keys}") + + return cls( + action_horizon=int(payload["action_horizon"]), + action_keys={str(key) for key in payload["action_keys"]}, + embodiment_tag=str(payload["embodiment_tag"]), + needs_session_id=bool(payload["needs_session_id"]), + ) + + +class OpenPIWebsocketClient: + def __init__( + self, + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + ) -> None: + require_dependencies() + self._uri = f"ws://{host}:{port}{path}" + self._packer = msgpack_numpy.Packer() + self._ws, self._server_metadata = self._connect() + + def _connect(self): + conn = websockets_client.connect( + self._uri, + compression=None, + max_size=None, + ping_interval=PING_INTERVAL_SECS, + ping_timeout=PING_TIMEOUT_SECS, + ) + metadata = msgpack_numpy.unpackb(conn.recv()) + if not isinstance(metadata, dict): + raise TypeError(f"Expected dict metadata from server, got {type(metadata)!r}") + return conn, metadata + + def get_server_metadata(self) -> dict[str, Any]: + return dict(self._server_metadata) + + def infer(self, obs: dict[str, Any]) -> dict[str, np.ndarray]: + payload = dict(obs) + payload["endpoint"] = "infer" + self._ws.send(self._packer.pack(payload)) + response = msgpack_numpy.unpackb(self._ws.recv()) + if isinstance(response, dict) and response.get("type") == "error": + raise RuntimeError(f"Inference failed: {response['message']}") + if not isinstance(response, dict): + raise RuntimeError(f"Expected dict actions from GR00T OpenPI endpoint, got {type(response)!r}") + return {str(key): np.asarray(value, dtype=np.float32) for key, value in response.items()} + + def reset(self, reset_info: dict[str, Any] | None = None) -> str: + payload = dict(reset_info or {}) + payload["endpoint"] = "reset" + self._ws.send(self._packer.pack(payload)) + response = msgpack_numpy.unpackb(self._ws.recv()) + if not isinstance(response, dict) or response.get("status") != "reset successful": + raise RuntimeError(f"Unexpected reset response: {response!r}") + 
return str(response["status"]) + + def close(self) -> None: + self._ws.close() + + +def build_droid_observation(*, session_id: str = DEFAULT_SESSION_ID) -> dict[str, Any]: + return { + "session_id": session_id, + "video": { + "exterior_image_1_left": np.zeros((1, 2, 256, 256, 3), dtype=np.uint8), + "wrist_image_left": np.zeros((1, 2, 256, 256, 3), dtype=np.uint8), + }, + "state": { + "eef_9d": _identity_eef_9d_state(), + "gripper_position": np.zeros((1, 1, 1), dtype=np.float32), + "joint_position": np.zeros((1, 1, 7), dtype=np.float32), + }, + "language": {LANGUAGE_KEY: [["pick up the object"]]}, + } + + +def run_policy_session( + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + session_id: str = DEFAULT_SESSION_ID, +) -> dict[str, Any]: + client = OpenPIWebsocketClient(host=host, port=port, path=path) + try: + metadata = client.get_server_metadata() + actions = client.infer(build_droid_observation(session_id=session_id)) + reset_status = client.reset({}) + return { + "metadata": metadata, + "actions": actions, + "reset_status": reset_status, + "session_id": session_id, + } + finally: + client.close() + + +def validate_session_result(result: dict[str, Any]) -> None: + metadata = Gr00tServerMetadata.from_dict(result["metadata"]) + if metadata.embodiment_tag != "OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT": + raise AssertionError(f"Unexpected embodiment_tag: {metadata.embodiment_tag}") + if not metadata.needs_session_id: + raise AssertionError("GR00T test expects needs_session_id metadata") + if metadata.action_keys != ACTION_KEYS: + raise AssertionError(f"Unexpected action keys: {metadata.action_keys}") + if result["reset_status"] != "reset successful": + raise AssertionError(f"Unexpected reset status: {result['reset_status']!r}") + + actions = result["actions"] + if set(actions) != ACTION_KEYS: + raise AssertionError(f"Unexpected action keys: {set(actions)}") + expected_shapes = { + "eef_9d": (1, metadata.action_horizon, 9), + 
"gripper_position": (1, metadata.action_horizon, 1), + "joint_position": (1, metadata.action_horizon, 7), + } + for key, expected_shape in expected_shapes.items(): + if actions[key].shape != expected_shape: + raise AssertionError(f"Action {key} shape mismatch: expected {expected_shape}, got {actions[key].shape}") + if actions[key].dtype != np.float32: + raise AssertionError(f"Action {key} dtype mismatch: expected float32, got {actions[key].dtype}") + if not np.isfinite(actions[key]).all(): + raise AssertionError(f"Action {key} contains non-finite values") diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py index 140bddf45ae..c375cc38d60 100644 --- a/vllm_omni/config/pipeline_registry.py +++ b/vllm_omni/config/pipeline_registry.py @@ -28,8 +28,6 @@ the entries declared here. """ -from __future__ import annotations - # --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) --- _OMNI_PIPELINES: dict[str, tuple[str, str]] = { # model_type -> (module_path, variable_name) @@ -69,6 +67,10 @@ "vllm_omni.model_executor.models.glm_image.pipeline", "GLM_IMAGE_PIPELINE", ), + "Gr00tN1d7": ( + "vllm_omni.model_executor.models.gr00t.pipeline", + "GR00T_N1D7_PIPELINE", + ), "hunyuan_image_3_moe": ( "vllm_omni.model_executor.models.hunyuan_image3.pipeline", "HUNYUAN_IMAGE3_PIPELINE", diff --git a/vllm_omni/deploy/Gr00tN1d7.yaml b/vllm_omni/deploy/Gr00tN1d7.yaml new file mode 100644 index 00000000000..8ebb28d5163 --- /dev/null +++ b/vllm_omni/deploy/Gr00tN1d7.yaml @@ -0,0 +1,40 @@ +# GR00T N1.7 deploy: single diffusion stage. +# +# Topology is declared in vllm_omni/model_executor/models/gr00t/pipeline.py. +# This default uses one GPU with TP=1 and CFG parallel disabled. 
+ +pipeline: Gr00tN1d7 +async_chunk: false +distributed_executor_backend: mp +dtype: bfloat16 + +stages: + - stage_id: 0 + devices: "0" + max_num_seqs: 1 + enforce_eager: true + model_class_name: Gr00tN1d7Pipeline + model_config: + embodiment_tag: OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT + strict: true + policy_server_config: + image_resolution: [180, 320] + needs_session_id: true + embodiment_tag: OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT + action_horizon: 40 + action_keys: + - eef_9d + - gripper_position + - joint_position + supported_embodiments: + - OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT + - XDOF + - XDOF_SUBTASK + - REAL_G1 + - REAL_R1_PRO_SHARPA + - REAL_R1_PRO_SHARPA_HUMAN + - REAL_R1_PRO_SHARPA_MAXINSIGHTS + - REAL_R1_PRO_SHARPA_MECKA + - LIBERO_PANDA + - SIMPLER_ENV_GOOGLE + - SIMPLER_ENV_WIDOWX diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 5fe2d4061cc..01026f35ebd 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -935,6 +935,10 @@ def enrich_config(self) -> None: self.model_class_name = "WanS2VPipeline" self.tf_model_config = TransformerConfig() self.update_multimodal_support() + elif model_type == "Gr00tN1d7" or "Gr00tN1d7" in architectures: + self.model_class_name = "Gr00tN1d7Pipeline" + self.set_tf_model_config(TransformerConfig()) + self.update_multimodal_support() elif architectures and len(architectures) == 1: self.model_class_name = architectures[0] else: @@ -1006,10 +1010,10 @@ class DiffusionOutput: """ # Fields may be replaced with SHM handle dicts by ipc.pack_diffusion_output_shm - output: torch.Tensor | dict | None = None - trajectory_timesteps: torch.Tensor | dict | None = None - trajectory_latents: torch.Tensor | dict | None = None - trajectory_log_probs: torch.Tensor | dict | None = None + output: torch.Tensor | tuple[Any, ...] 
| dict[str, Any] | None = None + trajectory_timesteps: torch.Tensor | dict[str, Any] | None = None + trajectory_latents: torch.Tensor | dict[str, Any] | None = None + trajectory_log_probs: torch.Tensor | dict[str, Any] | None = None trajectory_decoded: list[Image.Image] | None = None error: str | None = None aborted: bool = False @@ -1017,6 +1021,11 @@ class DiffusionOutput: post_process_func: Callable[..., Any] | None = None + # Multimodal payloads that should be surfaced on OmniRequestOutput. + # Robot policies use this for `{"actions": ...}` without pretending the + # action dict is an image/video output. + multimodal_output: dict[str, Any] = field(default_factory=dict) + # Extra custom output data (e.g. latent trajectories, prompt embeds) # passed through to OmniRequestOutput.custom_output custom_output: dict[str, Any] = field(default_factory=dict) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 1805b5dc349..d255f8b22d9 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import asyncio import concurrent.futures import inspect @@ -220,7 +218,8 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: raise RuntimeError(output.error) logger.debug("Generation completed successfully.") - if output.output is None: + base_multimodal_output = output.multimodal_output + if output.output is None and not base_multimodal_output: logger.warning("Output is None, returning empty OmniRequestOutput") return [ OmniRequestOutput.from_diffusion( @@ -258,12 +257,18 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: custom_output = output.custom_output or {} model_audio_sample_rate = None model_fps = None + non_visual_output_keys = {"audio", "actions", "custom_output", 
"audio_sample_rate", "fps"} if isinstance(outputs, dict): audio_payload = outputs.get("audio") + if "actions" in outputs: + base_multimodal_output["actions"] = outputs["actions"] custom_output.update(outputs.get("custom_output") or {}) model_audio_sample_rate = outputs.get("audio_sample_rate") model_fps = outputs.get("fps") - outputs = outputs.get("video", outputs) + if "video" in outputs: + outputs = outputs["video"] + elif outputs.keys() <= non_visual_output_keys: + outputs = None postprocess_time = time.perf_counter() - postprocess_start_time logger.debug("Post-processing completed in %.4f seconds", postprocess_time) @@ -346,13 +351,14 @@ def _audio_mm(payload: Any) -> dict[str, Any]: ), ] else: - mm_output = {} + mm_output = base_multimodal_output.copy() if audio_payload is not None: mm_output["audio"] = audio_payload if model_audio_sample_rate is not None: mm_output["audio_sample_rate"] = model_audio_sample_rate if model_fps is not None: mm_output["fps"] = model_fps + final_output_type = "actions" if "actions" in mm_output and not outputs else "image" return [ OmniRequestOutput.from_diffusion( request_id=request_id, @@ -366,6 +372,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]: trajectory_decoded=output.trajectory_decoded, custom_output=custom_output, multimodal_output=mm_output, + final_output_type=final_output_type, stage_durations=output.stage_durations, peak_memory_mb=output.peak_memory_mb, ), @@ -405,7 +412,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]: ), ) else: - mm_output = {} + mm_output = base_multimodal_output.copy() if audio_payload is not None: sliced_audio = audio_payload if isinstance(audio_payload, (list, tuple)): @@ -422,6 +429,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]: mm_output["audio_sample_rate"] = model_audio_sample_rate if model_fps is not None: mm_output["fps"] = model_fps + final_output_type = "actions" if "actions" in mm_output and not request_outputs else "image" results.append( OmniRequestOutput.from_diffusion( 
request_id=request_id, @@ -435,6 +443,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]: trajectory_decoded=output.trajectory_decoded, custom_output=custom_output, multimodal_output=mm_output, + final_output_type=final_output_type, stage_durations=output.stage_durations, peak_memory_mb=output.peak_memory_mb, ), @@ -569,7 +578,7 @@ def _handle_finished_requests( def make_engine( config: OmniDiffusionConfig, scheduler: SchedulerInterface | None = None, - ) -> DiffusionEngine: + ) -> "DiffusionEngine": """Factory method to create a DiffusionEngine instance. Args: diff --git a/vllm_omni/diffusion/models/gr00t/__init__.py b/vllm_omni/diffusion/models/gr00t/__init__.py new file mode 100644 index 00000000000..5a93bd798c6 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.gr00t.pipeline_gr00t import ( + Gr00tN1d7Pipeline, + get_gr00t_n1d7_post_process_func, +) + +__all__ = ["Gr00tN1d7Pipeline", "get_gr00t_n1d7_post_process_func"] diff --git a/vllm_omni/diffusion/models/gr00t/configs/__init__.py b/vllm_omni/diffusion/models/gr00t/configs/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/configs/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment/__init__.py b/vllm_omni/diffusion/models/gr00t/configs/embodiment/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/configs/embodiment/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py 
# ---- File: vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py
# ---- (new file mode 100644, index 00000000000..c7754c2ac9a)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag
from vllm_omni.diffusion.models.gr00t.dataio.types import (
    ActionConfig,
    ActionFormat,
    ActionRepresentation,
    ActionType,
    ModalityConfig,
)

# Registry of modality layouts, keyed first by embodiment tag *value* (the
# string form of an ``EmbodimentTag`` member) and then by modality name
# ("video", "state", "action", "language").  Each leaf is a ``ModalityConfig``.
# Where an "action" entry carries ``action_configs``, the list has one
# ``ActionConfig`` per entry of ``modality_keys``, in the same order.
# NOTE(review): ``delta_indices`` look like timestep offsets relative to the
# current step -- confirm against ``ModalityConfig``.
MODALITY_CONFIGS = {
    ##### Pre-registered pretrain configurations #####
    "oxe_droid_relative_eef_relative_joint": {
        "video": ModalityConfig(
            delta_indices=[-15, 0],
            modality_keys=["exterior_image_1_left", "wrist_image_left"],
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=["eef_9d", "gripper_position", "joint_position"],
        ),
        "action": ModalityConfig(
            delta_indices=list(range(40)),
            modality_keys=["eef_9d", "gripper_position", "joint_position"],
            action_configs=[
                ActionConfig(
                    rep=ActionRepresentation.RELATIVE,
                    type=ActionType.EEF,
                    format=ActionFormat.XYZ_ROT6D,
                    state_key="eef_9d",
                ),
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                    state_key="gripper_position",
                ),
                ActionConfig(
                    rep=ActionRepresentation.RELATIVE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                    state_key="joint_position",
                ),
            ],
        ),
        "language": ModalityConfig(
            delta_indices=[0],
            modality_keys=["annotation.language.language_instruction"],
        ),
    },
    ##### Pre-registered posttrain configurations #####
    "unitree_g1_sonic": {
        "video": ModalityConfig(
            delta_indices=[0],
            modality_keys=["ego_view"],
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=[
                "left_leg",
                "right_leg",
                "waist",
                "left_arm",
                "right_arm",
                "left_hand",
                "right_hand",
                "projected_gravity",
            ],
        ),
        "action": ModalityConfig(
            delta_indices=list(range(40)),
            modality_keys=[
                "motion_token",
                "left_hand_joints",
                "right_hand_joints",
            ],
            action_configs=[
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
            ],
        ),
        "language": ModalityConfig(
            delta_indices=[0],
            modality_keys=["annotation.human.task_description"],
        ),
    },
    "unitree_g1_full_body_with_waist_height_nav_cmd": {
        "video": ModalityConfig(
            delta_indices=[0],
            modality_keys=["ego_view"],
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=[
                "left_leg",
                "right_leg",
                "waist",
                "left_arm",
                "right_arm",
                "left_hand",
                "right_hand",
            ],
        ),
        "action": ModalityConfig(
            delta_indices=list(range(50)),
            modality_keys=[
                "left_arm",
                "right_arm",
                "left_hand",
                "right_hand",
                "waist",
                "base_height_command",
                "navigate_command",
            ],
            action_configs=[
                # left_arm
                ActionConfig(
                    rep=ActionRepresentation.RELATIVE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
                # right_arm
                ActionConfig(
                    rep=ActionRepresentation.RELATIVE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
                # left_hand
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,  # G1 hand is controlled by binary signals like a gripper
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
                # right_hand
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,  # G1 hand is controlled by binary signals like a gripper
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
                # waist
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
                # base_height_command
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
                # navigate_command
                ActionConfig(
                    rep=ActionRepresentation.ABSOLUTE,
                    type=ActionType.NON_EEF,
                    format=ActionFormat.DEFAULT,
                ),
            ],
        ),
        "language": ModalityConfig(
            delta_indices=[0],
            modality_keys=["annotation.human.task_description"],
        ),
    },
    "libero_sim": {
        "video": ModalityConfig(
            delta_indices=[0],
            modality_keys=["image", "wrist_image"],
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
        ),
        "action": ModalityConfig(
            delta_indices=list(range(16)),
            modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
        ),
        "language": ModalityConfig(
            delta_indices=[0],
            modality_keys=["annotation.human.action.task_description"],
        ),
    },
    "simpler_env_widowx": {
        "video": ModalityConfig(
            delta_indices=[0],
            modality_keys=["image_0"],
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"],
        ),
        "action": ModalityConfig(
            delta_indices=list(range(8)),
            modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
        ),
        "language": ModalityConfig(
            delta_indices=[0],
            modality_keys=["annotation.human.action.task_description"],
        ),
    },
    "simpler_env_google": {
        "video": ModalityConfig(
            delta_indices=[0],
            modality_keys=["image"],
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"],
        ),
        "action": ModalityConfig(
            delta_indices=list(range(8)),
            modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
        ),
        "language": ModalityConfig(
            delta_indices=[0],
            modality_keys=["annotation.human.action.task_description"],
        ),
    },
}


def register_modality_config(config: dict, embodiment_tag: EmbodimentTag = EmbodimentTag.NEW_EMBODIMENT):
    """Add the modality layout for an embodiment that has no entry yet.

    Args:
        config: Mapping stored under the tag; the built-in entries map a
            modality name to a ``ModalityConfig``.
        embodiment_tag: Tag whose ``.value`` becomes the key in
            ``MODALITY_CONFIGS``. Defaults to ``EmbodimentTag.NEW_EMBODIMENT``.

    Raises:
        AssertionError: If the tag already has an entry in ``MODALITY_CONFIGS``.
    """
    if embodiment_tag.value in MODALITY_CONFIGS:
        # Raised explicitly instead of via ``assert`` so the duplicate guard
        # still runs under ``python -O``; the exception type is unchanged.
        raise AssertionError(f"Embodiment tag {embodiment_tag} already registered")
    MODALITY_CONFIGS[embodiment_tag.value] = config


# ---- File: vllm_omni/diffusion/models/gr00t/configs/model/__init__.py
# ---- (new file mode 100644, index 00000000000..9648b859751)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Model config classes keyed by their short name.
MODEL_CONFIG_TYPES: dict[str, type] = {}


def register_model_config(shortname: str, configtype: type) -> None:
    """Store ``configtype`` under ``shortname``, replacing any previous entry."""
    MODEL_CONFIG_TYPES[shortname] = configtype


# ---- File: vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py
# ---- (new file mode 100644, index 00000000000..8b67958d330)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from dataclasses import MISSING, asdict, dataclass, is_dataclass
from enum import Enum
from pathlib import Path

import torch
from transformers import PretrainedConfig

from . import register_model_config


def _default_diffusion_model_cfg() -> dict:
    """Return a fresh dict of the default diffusion-model settings.

    A new dict is built on every call so config instances never share one.
    """
    return {
        "positional_embeddings": None,
        "num_layers": 16,
        "num_attention_heads": 32,
        "attention_head_dim": 48,
        "norm_type": "ada_norm",
        "dropout": 0.2,
        "final_dropout": True,
        "output_dim": 1024,
        "interleave_self_attention": True,
    }


@dataclass
class Gr00tN1d7Config(PretrainedConfig):
    """Unified configuration for Gr00tN1d7 model with backbone and action head.

    Gr00tN1d7 uses the Cosmos-Reason2-2B (Qwen3-VL architecture) VLM backbone,
    replacing the Eagle backbone used in Gr00tN1d6.

    The class defines its own ``__init__``, so values are supplied as keyword
    arguments and copied onto the instance; fields that are not supplied keep
    the class-level defaults declared below.
    """

    # Model identification
    model_type: str = "Gr00tN1d7"
    model_dtype: str = "bfloat16"  # Use bfloat16 for Flash Attention compatibility

    # Backbone configuration
    model_name: str = "nvidia/Cosmos-Reason2-2B"
    backbone_model_type: str = "qwen"
    model_revision: str | None = None
    tune_top_llm_layers: int = 0  # Number of top LLM layers to tune
    backbone_embedding_dim: int = 2048  # project_to_dim; must match Cosmos-Reason2-2B hidden size
    tune_llm: bool = False
    tune_visual: bool = False
    select_layer: int = 12
    reproject_vision: bool = False
    use_flash_attention: bool = True
    load_bf16: bool = False  # Enable BF16 loading
    backbone_trainable_params_fp32: bool = True

    ### Processing parameters
    image_crop_size: tuple[int, int] | None = (230, 230)
    image_target_size: tuple[int, int] | None = (256, 256)

    shortest_image_edge: int | None = None
    crop_fraction: float | None = None

    random_rotation_angle: int | None = None
    color_jitter_params: dict[str, float] | None = None
    use_albumentations_transforms: bool = True
    # Extra augmentation config (mask-based and others).
    extra_augmentation_config: dict | None = None
    formalize_language: bool = True
    apply_sincos_state_encoding: bool = False  # Global flag to enable per-embodiment sin/cos encoding
    use_percentiles: bool = True
    use_relative_action: bool = False

    # Action head configuration parameters
    max_state_dim: int = 132  # Default from state_shape
    max_action_dim: int = 132  # Default from action_shape
    action_horizon: int = 40
    hidden_size: int = 1024
    input_embedding_dim: int = 1536

    # State history: number of consecutive state timesteps fed to the state encoder
    state_history_length: int = 1

    # Global parameters
    add_pos_embed: bool = True
    attn_dropout: float = 0.2
    use_vlln: bool = True
    max_seq_len: int = 1024
    use_alternate_vl_dit: bool = True  # True for AlternateVLDiT, False for DiT
    attend_text_every_n_blocks: int = 2

    # ``None`` is replaced by ``_default_diffusion_model_cfg()`` in ``__init__``.
    diffusion_model_cfg: dict | None = None

    # Flow matching parameters
    num_inference_timesteps: int = 4
    noise_beta_alpha: float = 1.5
    noise_beta_beta: float = 1.0
    noise_s: float = 0.999
    num_timestep_buckets: int = 1000

    # Training parameters
    tune_projector: bool = True
    tune_diffusion_model: bool = True
    tune_vlln: bool = True

    # State augmentation parameters
    state_dropout_prob: float = 0.8  # State dropout probability
    exclude_state: bool = False  # Zero out all state inputs (ablation)
    use_mean_std: bool = False  # Use mean/std normalization instead of min/max

    # Multi-embodiment parameters
    max_num_embodiments: int = 32

    def __init__(self, **kwargs):
        """Build the config from keyword overrides.

        Every keyword is forwarded to ``PretrainedConfig.__init__`` and then
        also set as an attribute on this instance.
        """
        super().__init__(**kwargs)
        for key, value in kwargs.items():
            setattr(self, key, value)

        # Ensures that all dataclass defaults (including those using default_factory)
        # are explicitly assigned to the instance, even if dataclasses initialization or subclassing
        # (PretrainedConfig) interferes with normal default injection.
        # NOTE(review): every field above has a plain class-level default, so
        # ``hasattr`` is expected to be true already -- confirm whether this
        # loop can ever assign anything.
        for f in self.__dataclass_fields__.values():
            if not hasattr(self, f.name):
                if f.default is not MISSING:
                    setattr(self, f.name, f.default)
                elif getattr(f, "default_factory", MISSING) is not MISSING:
                    setattr(self, f.name, f.default_factory())

        if self.diffusion_model_cfg is None:
            self.diffusion_model_cfg = _default_diffusion_model_cfg()
        else:
            # Copy so later edits on this config do not touch the caller's mapping.
            self.diffusion_model_cfg = dict(self.diffusion_model_cfg)

    def to_filtered_dict(self, exclude_augment: bool = True) -> dict:
        """Return a dictionary representation of this config, optionally excluding augmentation keys.

        Args:
            exclude_augment: When true, drop the image-processing and
                augmentation keys listed in ``exclude_keys`` below.
        """
        if is_dataclass(self):
            cfg = asdict(self)
        else:
            cfg = dict(self.__dict__)

        if exclude_augment:
            exclude_keys = {
                "random_rotation_angle",
                "color_jitter_params",
                "use_albumentations_transforms",
                "formalize_language",
                "image_crop_size",
                "image_target_size",
                "shortest_image_edge",
                "crop_fraction",
            }
            cfg = {k: v for k, v in cfg.items() if k not in exclude_keys}

        return cfg

    def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
        """Return a JSON string of this config, optionally excluding augmentation keys.

        Args:
            exclude_augment: Forwarded to ``to_filtered_dict``.
            **kwargs: Extra keyword arguments for ``json.dumps``. ``indent``
                defaults to 2 and ``default`` to the fallback serializer
                below; either may be overridden by the caller.
        """

        def default(o):
            # Fallback for values ``json`` cannot encode natively.
            if isinstance(o, (Path, torch.dtype, torch.device)):
                return str(o)
            if isinstance(o, Enum):
                return o.value
            return str(o)

        # ``setdefault`` instead of fixed keywords: passing ``indent=`` or
        # ``default=`` in ``kwargs`` previously raised a duplicate-keyword TypeError.
        kwargs.setdefault("indent", 2)
        kwargs.setdefault("default", default)
        return json.dumps(self.to_filtered_dict(exclude_augment), **kwargs)


register_model_config("Gr00tN1d7", Gr00tN1d7Config)

# ---- File: vllm_omni/diffusion/models/gr00t/dataio/__init__.py
# ---- (new file mode 100644, index 00000000000..208f01a7cb5)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# ---- File: vllm_omni/diffusion/models/gr00t/dataio/collator/__init__.py
# ---- (new file mode 100644, index 00000000000..b39b02e3042)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm_omni.diffusion.models.gr00t.dataio.collator.collators import BasicDataCollator

__all__ = ["BasicDataCollator"]

# ---- File: vllm_omni/diffusion/models/gr00t/dataio/collator/collators.py
# ---- (new file mode 100755, index 00000000000..17a4c2ad474)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch


class BasicDataCollator:
    """Collate per-sample feature dicts into one batch by stacking per key."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        """Stack every field of ``features`` with ``torch.stack``.

        The key set is taken from the first sample; every sample must provide
        each of those keys. An empty ``features`` list raises ``IndexError``.
        """
        return {key: torch.stack([item[key] for item in features]) for key in features[0].keys()}


# ---- File: vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py
# ---- (new file mode 100755, index 00000000000..939d52c1c01)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum

"""
Embodiment tags are used to identify the robot embodiment in the data.

Naming convention:
<dataset>_<robot_name>

If using multiple datasets, e.g. sim GR1 and real GR1, we can drop the dataset name and use only the robot name.
"""


class EmbodimentTag(Enum):
    """Embodiment tags supported by the GR00T N1.7 checkpoint.

    Pretrain tags (baked into the base model nvidia/GR00T-N1.7-3B, inference-ready):
    - OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT -> "oxe_droid_relative_eef_relative_joint"
    - XDOF -> "xdof_relative_eef_relative_joint"
    - XDOF_SUBTASK -> "xdof_relative_eef_relative_joint_subtask"
    - REAL_G1 -> "real_g1_relative_eef_relative_joints"
    - REAL_R1_PRO_SHARPA -> "real_r1_pro_sharpa_relative_eef"
    - REAL_R1_PRO_SHARPA_HUMAN -> "real_r1_pro_sharpa_relative_eef_human"
    - REAL_R1_PRO_SHARPA_MAXINSIGHTS -> "real_r1_pro_sharpa_relative_eef_maxinsights"
    - REAL_R1_PRO_SHARPA_MECKA -> "real_r1_pro_sharpa_relative_eef_mecka"

    Pre-registered posttrain tags (require finetuned checkpoint):
    - UNITREE_G1 -> "unitree_g1_full_body_with_waist_height_nav_cmd"
    - UNITREE_G1_SONIC -> "unitree_g1_sonic"
    - SIMPLER_ENV_GOOGLE -> "simpler_env_google"
    - SIMPLER_ENV_WIDOWX -> "simpler_env_widowx"
    - LIBERO_PANDA -> "libero_sim"

    Finetuning tag (for custom robots):
    - NEW_EMBODIMENT -> "new_embodiment"

    Use ``EmbodimentTag.resolve(s)`` to look up a tag by name or value,
    case-insensitively.
    """

    ##### Pretrain embodiment tags (in base model processor_config.json) #####

    OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT = "oxe_droid_relative_eef_relative_joint"
    """
    The Open-X-Embodiment DROID robot with relative EEF and relative joint position actions.
    """

    XDOF = "xdof_relative_eef_relative_joint"
    """
    The generic X-DOF robot with relative EEF and relative joint position actions.
    """

    XDOF_SUBTASK = "xdof_relative_eef_relative_joint_subtask"
    """
    The generic X-DOF robot (subtask variant).
    """

    REAL_G1 = "real_g1_relative_eef_relative_joints"
    """
    Real-world Unitree G1 with relative EEF and relative joint actions.
    """

    REAL_R1_PRO_SHARPA = "real_r1_pro_sharpa_relative_eef"
    """
    Real-world R1 Pro Sharpa with relative EEF actions.
    """

    REAL_R1_PRO_SHARPA_HUMAN = "real_r1_pro_sharpa_relative_eef_human"
    """
    Real-world R1 Pro Sharpa with relative EEF actions (human teleop data).
    """

    REAL_R1_PRO_SHARPA_MAXINSIGHTS = "real_r1_pro_sharpa_relative_eef_maxinsights"
    """
    Real-world R1 Pro Sharpa with relative EEF actions (MaxInsights data, single-cam).
    """

    REAL_R1_PRO_SHARPA_MECKA = "real_r1_pro_sharpa_relative_eef_mecka"
    """
    Real-world R1 Pro Sharpa with relative EEF actions (Mecka data, single-cam).
    """

    ##### Pre-registered posttrain embodiment tags #####

    UNITREE_G1 = "unitree_g1_full_body_with_waist_height_nav_cmd"
    """
    The Unitree G1 robot (sim, full-body with waist height and nav commands).
    """

    UNITREE_G1_SONIC = "unitree_g1_sonic"
    """
    The Unitree G1 robot with SONIC whole-body controller. VLA action space is SONIC latents.
    """

    SIMPLER_ENV_GOOGLE = "simpler_env_google"
    """
    The SimplerEnv Google robot.
    """

    SIMPLER_ENV_WIDOWX = "simpler_env_widowx"
    """
    The SimplerEnv WidowX robot.
    """

    LIBERO_PANDA = "libero_sim"
    """
    The LIBERO Panda robot (used for LIBERO-Goal, LIBERO-Object, LIBERO-Spatial, LIBERO-10).
    """

    # New embodiment during post-training
    NEW_EMBODIMENT = "new_embodiment"
    """
    Any new embodiment.
    """

    @classmethod
    def resolve(cls, tag: "str | EmbodimentTag") -> "EmbodimentTag":
        """Resolve a string to an EmbodimentTag, case-insensitively.

        Matches by enum **name** first (e.g. ``"xdof"`` -> ``XDOF``), then by
        enum **value** (e.g. ``"xdof_relative_eef_relative_joint"`` -> ``XDOF``).
        Surrounding whitespace is ignored; an ``EmbodimentTag`` is returned as-is.

        Raises:
            ValueError: If *tag* does not match any known embodiment. The
                message lists every known tag, grouped by category.
        """
        if isinstance(tag, cls):
            return tag
        key_lower = tag.strip().lower()
        # Match by enum name (case-insensitive)
        for member in cls:
            if member.name.lower() == key_lower:
                return member
        # Match by enum value (case-insensitive)
        for member in cls:
            if member.value.lower() == key_lower:
                return member

        def _fmt(tags):
            # Walk the enum (definition order) and filter by membership: the
            # category sets are frozensets, whose iteration order follows
            # string hashes and therefore changes between interpreter runs.
            return "\n".join(f" {m.name:40s} -> {m.value}" for m in cls if m in tags)

        msg = (
            f"Unknown embodiment tag: {tag!r}\n\n"
            f" Base model tags (work with nvidia/GR00T-N1.7-3B):\n"
            f"{_fmt(PRETRAIN_TAGS)}\n\n"
            f" Posttrain tags (require a finetuned checkpoint):\n"
            f"{_fmt(POSTTRAIN_TAGS)}\n\n"
            f" Finetuning-only tags (for custom robots):\n"
            f"{_fmt(FINETUNE_ONLY_TAGS)}"
        )
        raise ValueError(msg)

    @classmethod
    def reverse_lookup(cls, value: str) -> "str":
        """Map a tag value string back to its enum name, or return the value as-is."""
        for member in cls:
            if member.value == value:
                return member.name
        return value


# Module-level tag category sets (cannot be Enum class attributes).
PRETRAIN_TAGS: frozenset[EmbodimentTag] = frozenset(
    {
        EmbodimentTag.OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT,
        EmbodimentTag.XDOF,
        EmbodimentTag.XDOF_SUBTASK,
        EmbodimentTag.REAL_G1,
        EmbodimentTag.REAL_R1_PRO_SHARPA,
        EmbodimentTag.REAL_R1_PRO_SHARPA_HUMAN,
        EmbodimentTag.REAL_R1_PRO_SHARPA_MAXINSIGHTS,
        EmbodimentTag.REAL_R1_PRO_SHARPA_MECKA,
    }
)
"""Tags baked into the base model (nvidia/GR00T-N1.7-3B) — usable without finetuning."""

POSTTRAIN_TAGS: frozenset[EmbodimentTag] = frozenset(
    {
        EmbodimentTag.UNITREE_G1,
        EmbodimentTag.UNITREE_G1_SONIC,
        EmbodimentTag.SIMPLER_ENV_GOOGLE,
        EmbodimentTag.SIMPLER_ENV_WIDOWX,
        EmbodimentTag.LIBERO_PANDA,
    }
)
"""Tags that require a finetuned checkpoint."""

FINETUNE_ONLY_TAGS: frozenset[EmbodimentTag] = frozenset(
    {
        EmbodimentTag.NEW_EMBODIMENT,
    }
)
"""Tags for custom robots (finetuning only, not in any shipped checkpoint)."""

# ---- File: vllm_omni/diffusion/models/gr00t/dataio/interfaces.py
# ---- (new file mode 100644, index 00000000000..ed8e2450004)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
from transformers import ProcessorMixin

from vllm_omni.diffusion.models.gr00t.dataio.types import EmbodimentTag, ModalityConfig


class BaseProcessor(ProcessorMixin):
    """Base interface for processors that turn messages into model inputs.

    ``__call__``, ``decode_action`` and ``collator`` raise
    ``NotImplementedError`` here and are meant to be overridden.
    """

    def __call__(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Process a list of messages and return a dictionary of model inputs.

        Args:
            messages (list[dict[str, Any]]): List of messages to process.

        Returns:
            dict[str, Any]: Dictionary of model inputs.

        Example:
            >>> processor = BaseProcessor()
            >>> messages = [
            >>>     {"type": MessageType.START_OF_EPISODE.value, "content": ""},
            >>>     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
            >>>     {"type": MessageType.TEXT.value, "role" : "user", "content": "Please give me the apple"},
            >>>     {"type": MessageType.TEXT.value, "role" : "assistant",
            >>>         "content": "I need to move my left hand to get the apple"},
            >>>     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
            >>>     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
            >>>     {"type": MessageType.END_OF_EPISODE.value, "content": ""},
            >>> ]
            >>> model_input = processor(messages)
            >>> print(model_input)
        """
        raise NotImplementedError("Subclasses must implement __call__")

    def decode_action(
        self,
        action: np.ndarray,
        embodiment_tag: EmbodimentTag,
        state: dict[str, np.ndarray] | None = None,
    ) -> dict[str, np.ndarray]:
        """Decode the action from the model output."""
        raise NotImplementedError("Subclasses must implement decode_action")

    @property
    def collator(self):
        """Collator paired with this processor; subclasses must provide it."""
        raise NotImplementedError("Subclasses must implement collator")

    # NOTE(review): ``@abstractmethod`` only blocks instantiation when the
    # class uses ``ABCMeta``; confirm whether ``ProcessorMixin`` provides it,
    # otherwise this default silently does nothing.
    @abstractmethod
    def set_statistics(self, statistics: dict[str, Any], override: bool = False) -> None:
        """Set normalization statistics."""
        pass

    def train(self):
        """Mark the processor as being in training mode."""
        self.training = True

    def eval(self):
        """Mark the processor as being in evaluation mode."""
        self.training = False

    def get_modality_configs(self) -> dict[str, dict[str, ModalityConfig]]:
        """Get the modality configurations.

        Returns:
            dict[str, dict[str, ModalityConfig]]: The modality configurations, where
                modality_configs[embodiment_tag][modality] = ModalityConfig

        Raises:
            AttributeError: If ``modality_configs`` has not been set on the instance.
        """
        return self.modality_configs


class ShardedDataset(ABC):
    """Abstract dataset that is read shard by shard."""

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of shards."""
        pass

    @abstractmethod
    def get_shard_length(self, idx: int) -> int:
        """Get the length of the shard at index idx."""
        pass

    @abstractmethod
    def get_shard(self, idx: int) -> list:
        """Get the shard at index idx."""
        pass

    def set_processor(self, processor: BaseProcessor):
        """Attach the processor this dataset should use."""
        self.processor = processor

    def get_dataset_statistics(self) -> dict[str, Any]:
        """Get the dataset statistics. This is only required for dataloaders for robotics datasets."""
        raise NotImplementedError()


# # Example chat formats (processor input)
# # Single step
# messages = [
#     {"type": "episode_step", "content": VLAStepData},
# ]
# # Single episode
# messages = [
#     {"type": MessageType.START_OF_EPISODE.value, "content": ""},
#     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
#     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
#     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
#     {"type": MessageType.END_OF_EPISODE.value, "content": ""},
# ]
# # Multiple episodes
# messages = [
#     {"type": MessageType.START_OF_EPISODE.value, "content": ""},
#     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
#     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
#     {"type": MessageType.END_OF_EPISODE.value, "content": ""},
#     {"type": MessageType.START_OF_EPISODE.value, "content": ""},
#     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
#     {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData},
#     {"type": MessageType.END_OF_EPISODE.value, "content": ""},
# ]

# # Example usage
# messages = dataset[idx]
# model_input = processor(messages)
# model_output = model(**model_input) # or model.generate(**model_input)
# decoded_action = processor.decode_action(model_output)

# ---- File: vllm_omni/diffusion/models/gr00t/dataio/state_action/__init__.py
# ---- (new file mode 100644, index 00000000000..208f01a7cb5)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# ---- File: vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py
# ---- (new file mode 100644, index 00000000000..28a9f030df6)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ActionChunk(Generic[PoseType]):
    """
    Base class for an action chunk: an ordered, time-stamped sequence of poses.

    Implements the operations that do not depend on the concrete pose type:
    relative and delta chunking (both built on the pose subtraction operator),
    indexing, length and repr. ``to_absolute_chunking``, ``interpolate`` and
    ``to`` raise NotImplementedError here and must be provided by subclasses.
    """

    def __init__(
        self,
        poses: Sequence[PoseType],
        times: Sequence[float] | NDArray[np.float64] | None = None,
    ):
        """
        Initialize an action chunk from a sequence of poses.

        Args:
            poses: Sequence of Pose objects.
            times: Optional timestamps, one per pose. If None, uniform spacing
                starting from 0 with step 1.0 is assumed.

        Raises:
            ValueError: If ``poses`` is empty, or ``times`` is not a 1-D
                sequence with exactly one entry per pose.
        """
        # Materialise first: ``not poses`` is ambiguous for array-like inputs
        # (e.g. a numpy object array), whereas a list has a well-defined truth value.
        self._poses: list[PoseType] = list(poses)
        if not self._poses:
            raise ValueError("ActionChunk must contain at least one pose")

        if times is None:
            self._times = np.arange(len(self._poses), dtype=np.float64)
        else:
            # np.array (not asarray) so the chunk never aliases the caller's buffer.
            times_array = np.array(times, dtype=np.float64)
            if times_array.ndim != 1 or times_array.shape[0] != len(self._poses):
                raise ValueError("Number of times must match number of poses")
            self._times = times_array

    @property
    def poses(self) -> list[PoseType]:
        """Return a shallow copy of the pose list."""
        return self._poses.copy()

    @property
    def times(self) -> NDArray[np.float64]:
        """Return a copy of the timestamp array."""
        return self._times.copy()

    @property
    def num_poses(self) -> int:
        """Return the number of poses in the chunk."""
        return len(self._poses)

    def relative_chunking(self, reference_frame: PoseType | None = None) -> "ActionChunk[PoseType]":
        """
        Express every pose relative to a single reference frame.

        Args:
            reference_frame: Optional reference pose. If None, the first pose
                of the chunk is used.

        Returns:
            A new chunk of the same class whose i-th pose is
            ``poses[i] - reference`` and whose timestamps are unchanged.
        """
        # No empty-chunk branch: the constructor guarantees at least one pose,
        # and building an empty chunk would raise anyway.
        reference = self._poses[0] if reference_frame is None else reference_frame

        # Pose subtraction is polymorphic and returns the operand type.
        relative_poses: list[PoseType] = [pose - reference for pose in self._poses]  # type: ignore[misc]

        # self.__class__ keeps the concrete chunk type (joint / end-effector).
        return self.__class__(relative_poses, times=self._times)

    def delta_chunking(self, reference_frame: PoseType | None = None) -> "ActionChunk[PoseType]":
        """
        Express every pose relative to the pose immediately before it.

        Args:
            reference_frame: Optional pose treated as the frame preceding the
                first pose. If None, the first delta is ``poses[0] - poses[0]``
                (the identity / zero pose).

        Returns:
            A new chunk of the same class whose i-th pose is
            ``poses[i] - poses[i - 1]`` and whose timestamps are unchanged.
        """
        previous = self._poses[0] if reference_frame is None else reference_frame

        delta_poses: list[PoseType] = []
        for current in self._poses:
            delta: PoseType = current - previous  # type: ignore[assignment]
            delta_poses.append(delta)
            previous = current  # the next delta is measured from this pose

        return self.__class__(delta_poses, times=self._times)

    def to_absolute_chunking(self, reference_frame: PoseType) -> "ActionChunk[PoseType]":
        """
        Convert a relative chunk back to absolute poses (inverse of relative_chunking()).

        Args:
            reference_frame: The reference pose the relative poses are applied on top of.

        Returns:
            A new ActionChunk of the same type with absolute poses.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement to_absolute_chunking")

    def interpolate(
        self,
        num_points: int | None = None,
        times: NDArray[np.float64] | None = None,
    ) -> "ActionChunk":
        """
        Interpolate the chunk to generate intermediate poses.

        Args:
            num_points: Number of evenly-spaced points to generate.
            times: Specific timestamps at which to interpolate.

        Returns:
            A new ActionChunk with interpolated poses.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement interpolate")

    def to(self, action_format: ActionFormat) -> NDArray[np.float64]:
        """
        Convert the chunk to an array in the requested action format.

        Args:
            action_format: The desired output format.

        Returns:
            Array in the requested format.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement to method")

    def __len__(self) -> int:
        """Return the number of poses in the chunk."""
        return len(self._poses)

    def __getitem__(self, index: int) -> PoseType:
        """Return the pose stored at ``index`` (delegates to list indexing)."""
        return self._poses[index]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(num_poses={len(self._poses)}, "
            f"time_range=[{self._times[0]:.2f}, {self._times[-1]:.2f}])"
        )
def interpolate(
    self,
    num_points: int | None = None,
    times: NDArray[np.float64] | None = None,
) -> "JointActionChunk":
    """
    Linearly interpolate the joint values of this chunk.

    Samples whose timestamp does not strictly advance past the last retained
    timestamp are discarded (with a logged warning) before interpolating, so
    the interpolation grid is always strictly increasing.

    Args:
        num_points: Number of evenly-spaced points to generate (including
            endpoints). Only used if ``times`` is None.
        times: Specific timestamps at which to interpolate. If provided,
            ``num_points`` is ignored.

    Returns:
        A new JointActionChunk with the interpolated poses. Joint names are
        taken from the first pose of this chunk.

    Raises:
        ValueError: If neither ``num_points`` nor ``times`` is provided, if
            fewer than 2 poses remain after discarding non-increasing
            timestamps, or if a requested time lies outside the retained
            timestamp range.
    """
    import logging  # function-scope import: keeps this method self-contained

    logger = logging.getLogger(__name__)

    if num_points is None and times is None:
        raise ValueError("Must provide either num_points or times")

    if len(self._poses) < 2:
        raise ValueError("Need at least 2 poses for interpolation")

    timestamps = self._times.copy()
    joint_values = np.array([pose.joints for pose in self._poses])  # (N, num_joints)

    # Compare against the last *kept* timestamp, not the original neighbour:
    # for [0, 5, 3, 4] a neighbour test keeps 4 (4 > 3) and leaves the
    # sequence unordered, which also corrupts the bounds check below.
    keep = np.ones(len(timestamps), dtype=bool)
    last_kept = timestamps[0]
    for idx in range(1, len(timestamps)):
        if timestamps[idx] <= last_kept:
            keep[idx] = False
            logger.warning(
                "Dropping non-increasing timestamp %s at index %d (last kept: %s)",
                timestamps[idx],
                idx,
                last_kept,
            )
        else:
            last_kept = timestamps[idx]
    timestamps = timestamps[keep]
    joint_values = joint_values[keep]

    if len(timestamps) < 2:
        raise ValueError("Need at least 2 poses with monotonic timestamps for interpolation")

    joint_interp = interpolate.interp1d(timestamps, joint_values, kind="linear", axis=0)

    if times is None:
        assert num_points is not None  # Type narrowing for type checker
        interp_times = np.linspace(timestamps[0], timestamps[-1], num_points)
    else:
        interp_times = np.array(times, dtype=np.float64)

    # interp1d would raise on out-of-range queries; fail earlier with a clearer message.
    if np.any(interp_times < timestamps[0]) or np.any(interp_times > timestamps[-1]):
        raise ValueError(f"Interpolation times must be within [{timestamps[0]}, {timestamps[-1]}]")

    interp_joint_values = joint_interp(interp_times)

    joint_names = self._poses[0].joint_names
    interpolated_poses = [JointPose(joints=row, joint_names=joint_names) for row in interp_joint_values]

    return JointActionChunk(interpolated_poses, times=interp_times)
+ + Args: + action_format: The desired output format + + Returns: + Array in the requested format + + Raises: + ValueError: If the action format is not supported for joint trajectories + """ + if action_format == ActionFormat.DEFAULT: + return self.to_array() + else: + raise ValueError( + f"ActionFormat {action_format} is not supported for JointActionChunk. " + f"Only {ActionFormat.DEFAULT} is supported." + ) + + +class EndEffectorActionChunk(ActionChunk[EndEffectorPose]): + """ + Represents action chunking in Cartesian space as a sequence of end-effector poses. + + Examples: + # Create an end-effector action chunking + ee_poses = [ + EndEffectorPose(translation=[0, 0, 0], rotation=[1, 0, 0, 0], + rotation_type="quat", rotation_order="wxyz"), + EndEffectorPose(translation=[1, 0, 0], rotation=[0.707, 0, 0, 0.707], + rotation_type="quat", rotation_order="wxyz"), + EndEffectorPose(translation=[2, 0, 0], rotation=[0, 0, 0, 1], + rotation_type="quat", rotation_order="wxyz"), + ] + action_chunking = EndEffectorActionChunk(ee_poses) + + # Get relative trajectory (all poses relative to first pose) + relative_traj = action_chunking.relative_chunking() + + # Get relative trajectory with custom reference frame + reference = EndEffectorPose(translation=[0.5, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + relative_traj = action_chunking.relative_chunking(reference_frame=reference) + + # Get delta trajectory + delta_traj = action_chunking.delta_chunking() + + # Convert relative trajectory back to absolute + reference = EndEffectorPose(translation=[0.5, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + absolute_traj = relative_traj.to_absolute_chunking(reference_frame=reference) + + # Interpolate trajectory + interpolated = action_chunking.interpolate(num_points=10) + + # Convert to desired format + from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat + homo_matrices = 
def interpolate(
    self,
    num_points: int | None = None,
    times: NDArray[np.float64] | None = None,
) -> "EndEffectorActionChunk":
    """
    Interpolate the chunk to generate intermediate end-effector poses.

    Translation is interpolated linearly and rotation with SLERP (spherical
    linear interpolation). Samples whose timestamp does not strictly advance
    past the last retained timestamp are discarded (with a logged warning)
    first, because Slerp requires strictly increasing times.

    Args:
        num_points: Number of evenly-spaced points to generate (including
            endpoints). Only used if ``times`` is None.
        times: Specific timestamps at which to interpolate. If provided,
            ``num_points`` is ignored.

    Returns:
        A new EndEffectorActionChunk with the interpolated poses.

    Raises:
        ValueError: If neither ``num_points`` nor ``times`` is provided, if
            fewer than 2 poses remain after discarding non-increasing
            timestamps, or if a requested time lies outside the retained
            timestamp range.
    """
    import logging  # function-scope import: keeps this method self-contained

    logger = logging.getLogger(__name__)

    if num_points is None and times is None:
        raise ValueError("Must provide either num_points or times")

    if len(self._poses) < 2:
        raise ValueError("Need at least 2 poses for interpolation")

    timestamps = self._times.copy()
    homogeneous_matrices = np.array([pose.homogeneous for pose in self._poses])  # (N, 4, 4)

    # Compare against the last *kept* timestamp, not the original neighbour:
    # for [0, 5, 3, 4] a neighbour test keeps 4 (4 > 3), the sequence stays
    # unordered and Slerp rejects it.
    keep = np.ones(len(timestamps), dtype=bool)
    last_kept = timestamps[0]
    for idx in range(1, len(timestamps)):
        if timestamps[idx] <= last_kept:
            keep[idx] = False
            logger.warning(
                "Dropping non-increasing timestamp %s at index %d (last kept: %s)",
                timestamps[idx],
                idx,
                last_kept,
            )
        else:
            last_kept = timestamps[idx]
    timestamps = timestamps[keep]
    homogeneous_matrices = homogeneous_matrices[keep]

    if len(timestamps) < 2:
        raise ValueError("Need at least 2 poses with monotonic timestamps for interpolation")

    # Split the retained transforms once, after filtering.
    positions = homogeneous_matrices[:, :3, 3]
    rotations = Rotation.from_matrix(homogeneous_matrices[:, :3, :3])

    pos_interp = interpolate.interp1d(timestamps, positions, kind="linear", axis=0)
    rot_interp = Slerp(timestamps, rotations)

    if times is None:
        assert num_points is not None  # Type narrowing for type checker
        interp_times = np.linspace(timestamps[0], timestamps[-1], num_points)
    else:
        interp_times = np.array(times, dtype=np.float64)

    # Both interpolators reject out-of-range queries; fail earlier with a clearer message.
    if np.any(interp_times < timestamps[0]) or np.any(interp_times > timestamps[-1]):
        raise ValueError(f"Interpolation times must be within [{timestamps[0]}, {timestamps[-1]}]")

    interp_positions = pos_interp(interp_times)
    interp_rotations = rot_interp(interp_times)

    interpolated_poses = [
        EndEffectorPose(
            translation=interp_positions[i],
            rotation=interp_rotations[i].as_matrix(),
            rotation_type="matrix",
        )
        for i in range(len(interp_times))
    ]

    return EndEffectorActionChunk(interpolated_poses, times=interp_times)
def to_absolute_chunking(self, reference_frame: EndEffectorPose) -> "EndEffectorActionChunk":
    """
    Turn a relative end-effector chunk into absolute poses.

    Inverse of relative_chunking(): every stored pose is read as a transform
    expressed in ``reference_frame`` and is left-multiplied by the reference
    transform, i.e. ``T_absolute = T_ref @ T_relative``.

    Args:
        reference_frame: The end-effector pose the relative poses are applied on top of.

    Returns:
        A new EndEffectorActionChunk holding the composed (absolute) poses,
        with the same timestamps as this chunk.
    """
    # Defensive guard, kept as in the original implementation.
    if len(self._poses) == 0:
        return EndEffectorActionChunk([], times=[])

    # Reference frame as a 4x4 homogeneous matrix, read once.
    base_transform = reference_frame.homogeneous

    composed_poses: list[EndEffectorPose] = [
        EndEffectorPose(homogeneous=base_transform @ relative.homogeneous) for relative in self._poses
    ]

    return EndEffectorActionChunk(composed_poses, times=self.times)
def invert_transformation(transform: NDArray[np.float64]) -> NDArray[np.float64]:
    """
    Invert one or more rigid homogeneous transformation matrices.

    Uses the closed form for a rigid transform ``[R | t]``: the inverse
    rotation is ``R^T`` (R is orthogonal) and the inverse translation is
    ``-R^T @ t``. No general matrix inversion is performed, so the result is
    only meaningful when the upper-left 3x3 block is a rotation matrix.

    Args:
        transform: Array-like of shape (4, 4), or a batch of shape (..., 4, 4).

    Returns:
        float64 array with the same shape as ``transform`` holding the inverse
        of each transformation.

    Raises:
        ValueError: If the trailing two dimensions of ``transform`` are not (4, 4).
    """
    transform = np.asarray(transform, dtype=np.float64)
    if transform.shape[-2:] != (4, 4):
        raise ValueError(f"Expected transform of shape (..., 4, 4), got {transform.shape}")

    # Transposing the last two axes handles a single matrix and a batch alike.
    rotation_inv = np.swapaxes(transform[..., :3, :3], -1, -2)
    translation = transform[..., :3, 3]

    inverse = np.zeros(transform.shape, dtype=np.float64)
    inverse[..., :3, :3] = rotation_inv
    # Batched matrix-vector product: -R^T @ t for every transform.
    inverse[..., :3, 3] = -np.einsum("...ij,...j->...i", rotation_inv, translation)
    inverse[..., 3, 3] = 1.0  # bottom row of a homogeneous matrix is [0, 0, 0, 1]

    return inverse
class JointPose(Pose):
    """
    Represents a robot configuration in joint space.

    Stores one value per joint (as a float64 array) together with a name for
    each joint.

    Examples:
        # Create a 6-DOF joint configuration
        joint_pose = JointPose(
            joints=[0.0, -np.pi/4, np.pi/2, 0.0, np.pi/4, 0.0],
            joint_names=["shoulder_pan", "shoulder_lift", "elbow",
                         "wrist_1", "wrist_2", "wrist_3"]
        )

        # Create with default joint names ("joint_0", "joint_1", ...)
        joint_pose = JointPose(joints=[0.0, 0.5, 1.0])

        # Get as dictionary
        joint_dict = joint_pose.to_dict()  # {"joint_0": 0.0, ...}

        # Compute relative joint displacement
        joint1 = JointPose([0.0, 0.5, 1.0])
        joint2 = JointPose([0.1, 0.6, 1.2])
        relative = joint2 - joint1  # [0.1, 0.1, 0.2]
    """

    pose_type = "joint"

    # Value-based __eq__ on a mutable object: keep instances unhashable, and
    # say so explicitly instead of relying on Python's implicit behaviour.
    __hash__ = None  # type: ignore[assignment]

    def __init__(
        self,
        joints: list | np.ndarray,
        joint_names: list | None = None,
    ):
        """
        Initialize a joint pose.

        Args:
            joints: Joint angles/positions as array-like of shape (n,)
            joint_names: Optional sequence of names for each joint. If None,
                defaults to ["joint_0", "joint_1", ...]

        Raises:
            ValueError: If the number of joint names differs from the number of joints.
        """
        super().__init__()
        self.joints = np.array(joints, dtype=np.float64)

        if joint_names is None:
            self.joint_names = [f"joint_{i}" for i in range(len(self.joints))]
        else:
            if len(joint_names) != len(self.joints):
                raise ValueError(
                    f"Number of joint names ({len(joint_names)}) must match number of joints ({len(self.joints)})"
                )
            # Store a private list: avoids aliasing the caller's sequence and
            # normalises tuples so copy() and __eq__ behave uniformly.
            self.joint_names = list(joint_names)

    @property
    def num_joints(self) -> int:
        """
        Get the number of joints.

        Returns:
            Number of joints in the configuration
        """
        return len(self.joints)

    def to_dict(self) -> dict:
        """
        Convert joint configuration to dictionary.

        Returns:
            Dictionary mapping joint names to joint values
        """
        return dict(zip(self.joint_names, self.joints))

    def _compute_relative(self, other):  # type: ignore[override]
        """
        Compute relative joint displacement.

        Args:
            other: Reference joint pose

        Returns:
            JointPose representing the joint-space difference (self - other),
            carrying the joint names of ``self``.

        Raises:
            ValueError: If joint dimensions don't match
        """
        if len(self.joints) != len(other.joints):
            raise ValueError(
                f"Cannot compute relative joint pose: "
                f"joint dimensions don't match ({len(self.joints)} vs {len(other.joints)})"
            )

        relative_joints = self.joints - other.joints
        return JointPose(joints=relative_joints, joint_names=self.joint_names)

    def copy(self) -> "JointPose":
        """
        Create a deep copy of this joint pose.

        Returns:
            New JointPose instance with copied data
        """
        return JointPose(
            joints=self.joints.copy(),
            # list(...) rather than .copy(): works for any sequence type.
            joint_names=list(self.joint_names),
        )

    def __repr__(self) -> str:
        if len(self.joints) <= 6:
            joints_str = np.array2string(self.joints, precision=4, suppress_small=True)
        else:
            # Abbreviate long configurations to first/last value plus a count.
            joints_str = f"[{self.joints[0]:.4f}, ..., {self.joints[-1]:.4f}] ({len(self.joints)} joints)"

        return f"JointPose(joints={joints_str})"

    def __eq__(self, other) -> bool:
        """Return True if ``other`` is a JointPose with close joint values and equal names."""
        if not isinstance(other, JointPose):
            return False
        # np.allclose raises on shapes that cannot be broadcast (e.g. 3 vs 2
        # joints); poses with different shapes are simply not equal.
        if self.joints.shape != other.joints.shape:
            return False
        return bool(np.allclose(self.joints, other.joints)) and list(self.joint_names) == list(other.joint_names)

    def __getitem__(self, index) -> float | NDArray[np.float64]:
        """Allow indexing: joint_pose[0] returns first joint value"""
        return self.joints[index]

    def __len__(self) -> int:
        """Allow len(): len(joint_pose) returns number of joints"""
        return len(self.joints)
+ + Examples: + # Create with quaternion (wxyz order) + pose = EndEffectorPose( + translation=[1.0, 2.0, 3.0], + rotation=[1.0, 0.0, 0.0, 0.0], + rotation_type="quat", + rotation_order="wxyz" + ) + + # Create with Euler angles (degrees by default) + pose = EndEffectorPose( + translation=[1, 2, 3], + rotation=[0, 0, 90], + rotation_type="euler", + rotation_order="xyz" + ) + + # Create with Euler angles in radians + pose = EndEffectorPose( + translation=[1, 2, 3], + rotation=[0, 0, np.pi/2], + rotation_type="euler", + rotation_order="xyz", + degrees=False + ) + + # Create from homogeneous matrix + H = np.eye(4) + H[:3, 3] = [1, 2, 3] + pose = EndEffectorPose(homogeneous=H) + + # Convert between representations + quat_wxyz = pose.to_rotation("quat", "wxyz") + euler_zyx = pose.to_rotation("euler", "zyx") + rot6d = pose.to_rotation("rot6d") + + # Compute relative transformation + pose1 = EndEffectorPose(translation=[1, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + pose2 = EndEffectorPose(translation=[2, 0, 0], rotation=[1,0,0,0], + rotation_type="quat", rotation_order="wxyz") + relative = pose2 - pose1 # Transformation from pose1's frame to pose2's frame + """ + + pose_type = "end_effector" + + def __init__( + self, + translation: list | np.ndarray | None = None, + rotation: list | np.ndarray | None = None, + rotation_type: str | None = None, + rotation_order: str | None = None, + homogeneous: np.ndarray | None = None, + degrees: bool = True, + ): + """ + Initialize an end-effector pose. 
@staticmethod
def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
    """
    Convert 6D rotation representation to rotation matrix.

    The first row is normalized, the second row is made orthogonal to it with
    one Gram-Schmidt step, and the third row is the cross product of the two,
    so the returned rows are orthonormal and right-handed.

    Args:
        rot6d: 6D rotation as a (6,) array-like - first two rows of rotation
            matrix flattened. Any layout holding exactly six values
            (for example (2, 3)) is accepted.

    Returns:
        Rotation matrix (3, 3)

    Raises:
        ValueError: If the input does not hold exactly six values, or if the
            two rows cannot span a plane (zero, non-finite, or collinear rows).
    """
    rows = np.asarray(rot6d, dtype=np.float64)
    if rows.size != 6:
        raise ValueError(f"rot6d must contain exactly 6 values, got shape {rows.shape}")
    row1, row2 = rows.reshape(2, 3)

    # Normalize first row; a zero or non-finite row has no direction.
    norm1 = np.linalg.norm(row1)
    if not np.isfinite(norm1) or norm1 == 0.0:
        raise ValueError(f"rot6d first row must be a finite non-zero vector, got {row1}")
    row1 = row1 / norm1

    # Gram-Schmidt orthogonalization for second row
    row2_scale = np.linalg.norm(row2)
    row2 = row2 - np.dot(row1, row2) * row1
    norm2 = np.linalg.norm(row2)
    # The residual is compared against the original length of the row so the
    # check is scale independent; dividing by a vanishing residual would
    # otherwise produce NaN/inf entries instead of a rotation.
    if not np.isfinite(norm2) or norm2 <= 1e-12 * row2_scale:
        raise ValueError("rot6d second row must be finite, non-zero and not collinear with the first row")
    row2 = row2 / norm2

    # Third row is cross product
    row3 = np.cross(row1, row2)

    # Construct rotation matrix
    return np.vstack([row1, row2, row3])
@property
def homogeneous(self) -> np.ndarray:
    """
    Return the pose as a homogeneous transformation matrix.

    The matrix is built by ``_compute_homogeneous`` the first time it is
    needed (or after the cache has been invalidated) and then reused. Callers
    always receive a copy, so editing the result never touches the cache.

    Returns:
        Homogeneous matrix - shape (4, 4)
    """
    if self._cache_valid:
        cached = self._homogeneous_cache
    else:
        # Rebuild, store, and only then mark the cache as usable.
        cached = self._compute_homogeneous()
        self._homogeneous_cache = cached
        self._cache_valid = True
    assert cached is not None
    return cached.copy()
transformation matrix. + (Alias for the homogeneous property) + + Returns: + Homogeneous matrix - shape (4, 4) + """ + return self.homogeneous + + def set_rotation( + self, + rotation: list | np.ndarray, + rotation_type: str, + rotation_order: str | None = None, + degrees: bool = True, + ): + """ + Set rotation from specified representation. + + Args: + rotation: Rotation data + rotation_type: Type of rotation ("quat", "euler", "rotvec", "matrix", "rot6d") + rotation_order: Order/convention for the rotation type + degrees: For Euler angles, whether the input is in degrees (default True) + """ + self._set_rotation(rotation, rotation_type, rotation_order, degrees) + + def _compute_relative(self, other): # type: ignore[override] + """ + Compute relative transformation from other to self. + + The result represents the transformation needed to go from other's frame to self's frame. + Mathematically: T_relative = T_other^{-1} * T_self + + Args: + other: Reference end-effector pose + + Returns: + EndEffectorPose representing the relative transformation + """ + # Get homogeneous matrices + T_self = self.homogeneous + T_other = other.homogeneous + + # Compute relative transformation: T_other^{-1} * T_self + T_relative = relative_transformation(T_other, T_self) + + # Create new EndEffectorPose from relative transformation + return EndEffectorPose(homogeneous=T_relative) + + @classmethod + def from_action_format(cls, data: np.ndarray, action_format: ActionFormat) -> "EndEffectorPose": + """ + Create an EndEffectorPose from a flat array using the specified action format. + + This is the inverse of the xyz_rot6d / xyz_rotvec / homogeneous properties. + + Args: + data: Flat array whose layout depends on action_format. + action_format: One of ActionFormat.XYZ_ROT6D, XYZ_ROTVEC, or DEFAULT. + + Returns: + EndEffectorPose instance. 
def __eq__(self, other) -> bool:
    """
    Compare two end-effector poses for approximate equality.

    Two poses are equal when their translations and their rotations agree
    within ``np.allclose`` tolerances.

    Args:
        other: Object to compare against.

    Returns:
        True if ``other`` is an EndEffectorPose describing the same pose,
        False otherwise (including when ``other`` is of a different type).
    """
    if not isinstance(other, EndEffectorPose):
        return False
    if not np.allclose(self._translation, other._translation):
        return False

    own_quat = self._rotation.as_quat()
    other_quat = other._rotation.as_quat()
    # A unit quaternion q and its negation -q encode the same rotation, so a
    # raw element-wise comparison reports identical poses as different. Flip
    # the other quaternion onto the same hemisphere before comparing.
    if np.dot(own_quat, other_quat) < 0.0:
        other_quat = -other_quat
    return bool(np.allclose(own_quat, other_quat))
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified processor for robot state and action data. + +Handles: +- State normalization (min/max, mean/std, sin/cos encoding) +- Action normalization +- Absolute <-> Relative action representation conversion +- Action processing with state dependency +""" + +from copy import deepcopy + +import numpy as np + +from vllm_omni.diffusion.models.gr00t.configs.embodiment.embodiment_configs import ( + ActionFormat, + ActionRepresentation, + ActionType, + ModalityConfig, +) +from vllm_omni.diffusion.models.gr00t.dataio.state_action.action_chunking import ( + EndEffectorActionChunk, + JointActionChunk, +) +from vllm_omni.diffusion.models.gr00t.dataio.state_action.pose import EndEffectorPose, JointPose +from vllm_omni.diffusion.models.gr00t.dataio.utils import ( + apply_sin_cos_encoding, + nested_dict_to_numpy, + normalize_values_meanstd, + normalize_values_minmax, + parse_modality_configs, + unnormalize_values_meanstd, + unnormalize_values_minmax, +) + + +class StateActionProcessor: + """ + Unified processor for robot state and action data. 
+ + Handles: + - State normalization (min/max, mean/std, sin/cos encoding) + - Action normalization + - Absolute <-> Relative action representation conversion + - Action processing with state dependency + """ + + def __init__( + self, + modality_configs: dict[str, dict[str, ModalityConfig]], + statistics: (dict[str, dict[str, dict[str, dict[str, list[float]]]]] | None) = None, + use_percentiles: bool = False, + clip_outliers: bool = True, + apply_sincos_state_encoding: bool = False, + use_relative_action: bool = False, + ): + """ + Initialize unified state and action processor. + + Args: + modality_configs: Nested dict with structure: + {embodiment_tag: {modality: ModalityConfig}} + where modality in ["state", "action"] + Example: {"gr1": {"state": ModalityConfig(...), "action": ModalityConfig(...)}} + statistics: Optional nested dict with structure: + {embodiment_tag: {modality: {joint_group: {stat_type: values}}}} + where modality in ["state", "action", "relative_action"] + and stat_type in ["min", "max", "mean", "std", "q01", "q99"] + Example: {"gr1": {"state": {"left_arm": {"min": [...], "max": [...], ...}}}} + use_percentiles: Whether to use percentiles (q01/q99) instead of min/max + clip_outliers: Whether to clip normalized values to [-1, 1] + apply_sincos_state_encoding: Global flag to enable sin/cos encoding for states + """ + self.modality_configs = parse_modality_configs(modality_configs) + self.statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]] = {} + self.use_percentiles = use_percentiles + self.clip_outliers = clip_outliers + self.apply_sincos_state_encoding = apply_sincos_state_encoding + self.use_relative_action = use_relative_action + + # Normalization parameters computed from statistics + self.norm_params: dict[str, dict[str, dict[str, dict[str, np.ndarray]]]] = {} + # Format: norm_params[embodiment_tag][modality][joint_group][stat_type] + # where stat_type in ["min", "max", "mean", "std", "dim"] + + if statistics is not None: + 
def _compute_normalization_parameters(self) -> None:
    """
    Compute and cache normalization parameters from statistics for all embodiments and modalities.

    For every embodiment in ``self.statistics`` the "state" and "action"
    entries are converted to numpy arrays and stored in ``self.norm_params``
    under the keys "min", "max", "mean", "std" and "dim". "min"/"max" come
    from "q01"/"q99" when ``self.use_percentiles`` is set.

    When ``self.use_relative_action`` is set, action groups configured as
    ``ActionRepresentation.RELATIVE`` have their parameters replaced by the
    embodiment's "relative_action" statistics (the "dim" entry is kept).

    Embodiments that have statistics but no entry in ``self.modality_configs``
    only get the plain state/action parameters; there is no action
    configuration that could request the relative override for them.

    Raises:
        ValueError: If relative statistics are required for an action group
            but missing from ``self.statistics``.
    """
    for embodiment_tag, embodiment_stats in self.statistics.items():
        tag_params = {}
        self.norm_params[embodiment_tag] = tag_params

        for modality in ("state", "action"):
            if modality not in embodiment_stats:
                continue

            tag_params[modality] = {}

            for joint_group, stats in embodiment_stats[modality].items():
                if self.use_percentiles:
                    min_vals = np.array(stats["q01"])
                    max_vals = np.array(stats["q99"])
                else:
                    min_vals = np.array(stats["min"])
                    max_vals = np.array(stats["max"])

                mean_vals = np.array(stats["mean"])
                std_vals = np.array(stats["std"])

                tag_params[modality][joint_group] = {
                    "min": min_vals,
                    "max": max_vals,
                    # Length of the (broadcast) per-dimension range vector.
                    "dim": np.array((max_vals - min_vals).shape[0]),
                    "mean": mean_vals,
                    "std": std_vals,
                }

        # Statistics may cover embodiments that were never configured; those
        # cannot request a relative override, so there is nothing left to do.
        if embodiment_tag not in self.modality_configs:
            continue
        embodiment_config = self.modality_configs[embodiment_tag]
        if "action" not in embodiment_config:
            continue

        # Override absolute action stats with relative stats where specified
        modality_keys = embodiment_config["action"].modality_keys
        action_configs = embodiment_config["action"].action_configs
        if action_configs is None or not self.use_relative_action:
            continue

        for key, action_config in zip(modality_keys, action_configs):
            if action_config.rep != ActionRepresentation.RELATIVE:
                continue
            if "relative_action" not in embodiment_stats:
                raise ValueError(
                    f"Relative action statistics required for embodiment '{embodiment_tag}' "
                    f"but 'relative_action' not found in statistics"
                )
            if key not in embodiment_stats["relative_action"]:
                raise ValueError(
                    f"Relative action statistics required for key '{key}' "
                    f"in embodiment '{embodiment_tag}' but not found"
                )
            # Keep the dimension computed from the absolute statistics.
            action_dim = tag_params["action"][key]["dim"]
            tag_params["action"][key] = nested_dict_to_numpy(embodiment_stats["relative_action"][key])
            tag_params["action"][key]["dim"] = action_dim
def unapply_state(
    self,
    state: dict[str, np.ndarray],
    embodiment_tag: str,
) -> dict[str, np.ndarray]:
    """
    Undo state processing and recover raw state values.

    Each configured joint group is denormalized with the inverse of the
    scheme used on the way in: mean/std for groups listed in
    ``mean_std_embedding_keys``, min/max for everything else.

    Args:
        state: Dict mapping joint_group -> processed state values
        embodiment_tag: Embodiment identifier

    Returns:
        Dict mapping joint_group -> raw state values

    Raises:
        KeyError: If a configured joint group is missing from ``state``
        ValueError: If a joint group was sin/cos encoded (not reversible)
    """
    embodiment_config = self.modality_configs[embodiment_tag]

    # Sin/cos encoded groups are only relevant when the encoding is enabled.
    irreversible_keys = None
    if self.apply_sincos_state_encoding:
        optional_state_config = embodiment_config.get("state")
        if optional_state_config and hasattr(optional_state_config, "sin_cos_embedding_keys"):
            irreversible_keys = optional_state_config.sin_cos_embedding_keys

    state_config = embodiment_config["state"]
    mean_std_keys = getattr(state_config, "mean_std_embedding_keys", None)

    restored = {}
    for joint_group in state_config.modality_keys:
        if joint_group not in state:
            raise KeyError(f"Joint group '{joint_group}' not found in state dict for embodiment '{embodiment_tag}'")

        # Sin/cos encoding is not reversible
        if irreversible_keys and joint_group in irreversible_keys:
            raise ValueError(
                f"Cannot unapply sin/cos encoding for joint group '{joint_group}' "
                f"in embodiment '{embodiment_tag}'. This transformation is not reversible."
            )

        params = self.norm_params[embodiment_tag]["state"][joint_group]
        if mean_std_keys and joint_group in mean_std_keys:
            # Reverse mean/std normalization
            restored[joint_group] = unnormalize_values_meanstd(state[joint_group], params)
        else:
            # Reverse min/max normalization
            restored[joint_group] = unnormalize_values_minmax(state[joint_group], params)

    return restored
Normalize actions + + Args: + action: Dict mapping joint_group -> raw action values + Shape per group: (T, D) where T is action horizon, D is action dimension + embodiment_tag: Embodiment identifier + state: Optional dict mapping joint_group -> raw state values + Required if any action group uses ActionRepresentation.RELATIVE + Shape per group: (T_state, D) where last timestep is used as reference + + Returns: + Dict mapping joint_group -> processed action values + Shape per group: (T, D) + + Raises: + ValueError: If state is None but required for relative action conversion + """ + action = deepcopy(action) # Avoid modifying input + + # Step 1: Convert absolute actions to relative (if needed) + modality_keys = self.modality_configs[embodiment_tag]["action"].modality_keys + action_configs = self.modality_configs[embodiment_tag]["action"].action_configs + + if action_configs is not None: + for key, action_config in zip(modality_keys, action_configs): + if action_config.rep == ActionRepresentation.RELATIVE and self.use_relative_action: + if state is None: + raise ValueError( + f"State dict required for relative action processing of key '{key}' " + f"in embodiment '{embodiment_tag}'" + ) + + # Determine which state key to use as reference + state_key = action_config.state_key if action_config.state_key else key + + if state_key not in state: + raise KeyError( + f"Reference state key '{state_key}' not found in state dict " + f"for embodiment '{embodiment_tag}'" + ) + + # Use last state as reference frame + reference_state = state[state_key][-1] + + # Convert absolute to relative + action[key] = self._convert_to_relative_action( + action=action[key], + reference_state=reference_state, + action_type=action_config.type, + action_format=action_config.format, + ) + + # Step 2: Normalize actions + normalized_values = {} + for joint_group in modality_keys: + if joint_group not in action: + raise KeyError( + f"Joint group '{joint_group}' not found in action dict for embodiment 
def unapply_action(
    self,
    action: dict[str, np.ndarray],
    embodiment_tag: str,
    state: dict[str, np.ndarray] | None = None,
) -> dict[str, np.ndarray]:
    """
    Reverse action processing (denormalization, relative->absolute conversion).

    Processing order:
    1. Denormalize actions
    2. Convert relative actions to absolute (if configured)

    Args:
        action: Dict mapping joint_group -> processed action values
            Shape per group: (T, D) or (B, T, D) for batched
        embodiment_tag: Embodiment identifier
        state: Optional dict mapping joint_group -> raw state values
            Required if any action group uses ActionRepresentation.RELATIVE
            Shape per group: (T_state, D) or (B, T_state, D) for batched.
            A single state trajectory ((T_state, D) or a batch of size one)
            is used as the reference for every element of a batched action.

    Returns:
        Dict mapping joint_group -> raw absolute action values
        Shape per group: (T, D) or (B, T, D) for batched

    Raises:
        KeyError: If a configured joint group is missing from ``action`` or
            the reference state key is missing from ``state``
        ValueError: If state is None but required for relative->absolute
            conversion, if an array has an unsupported rank, or if the
            action and state batch sizes disagree
    """
    action_modality = self.modality_configs[embodiment_tag]["action"]
    modality_keys = action_modality.modality_keys
    mean_std_keys = action_modality.mean_std_embedding_keys or ()

    # Step 1: Unnormalize actions
    unnormalized_values = {}
    for joint_group in modality_keys:
        if joint_group not in action:
            raise KeyError(
                f"Joint group '{joint_group}' not found in action dict for embodiment '{embodiment_tag}'"
            )

        params = self.norm_params[embodiment_tag]["action"][joint_group]
        if joint_group in mean_std_keys:
            unnormalized_values[joint_group] = unnormalize_values_meanstd(action[joint_group], params)
        else:
            unnormalized_values[joint_group] = unnormalize_values_minmax(action[joint_group], params)

    # Step 2: Convert relative actions to absolute (if needed)
    action_configs = action_modality.action_configs
    if action_configs is None or not self.use_relative_action:
        return unnormalized_values

    for key, action_config in zip(modality_keys, action_configs):
        if action_config.rep != ActionRepresentation.RELATIVE:
            continue

        if state is None:
            raise ValueError(
                f"State dict required for relative->absolute conversion of key '{key}' "
                f"in embodiment '{embodiment_tag}'"
            )

        # Determine which state key to use as reference
        state_key = action_config.state_key if action_config.state_key else key

        if state_key not in state:
            raise KeyError(
                f"Reference state key '{state_key}' not found in state dict "
                f"for embodiment '{embodiment_tag}'"
            )

        relative_action = unnormalized_values[key]
        if relative_action.ndim not in (2, 3):
            raise ValueError(
                f"Expected action shape (T, D) or (B, T, D) for key '{key}', got {relative_action.shape}"
            )

        # Bring both arrays to an explicit batch layout: (B, T, D) / (B, T_state, D)
        is_batched = relative_action.ndim == 3
        if not is_batched:
            relative_action = relative_action[None, ...]

        reference_state = state[state_key]
        if reference_state.ndim == 2:
            reference_state = reference_state[None, ...]
        if reference_state.ndim != 3:
            raise ValueError(
                f"Expected state shape (T_state, D) or (B, T_state, D) for key '{state_key}', "
                f"got {state[state_key].shape}"
            )

        batch_size = relative_action.shape[0]
        if reference_state.shape[0] == batch_size:
            reference_frames = reference_state
        elif reference_state.shape[0] == 1:
            # One shared state trajectory: reuse it for every batch element
            # instead of letting zip() drop all but the first action.
            reference_frames = [reference_state[0]] * batch_size
        else:
            raise ValueError(
                f"Batch size mismatch for key '{key}': action has {batch_size} element(s) "
                f"but reference state '{state_key}' has {reference_state.shape[0]}"
            )

        # Use last timestep of each state trajectory as reference
        absolute_actions = [
            self._convert_to_absolute_action(
                action=single_action,
                reference_state=single_state[-1],
                action_type=action_config.type,
                action_format=action_config.format,
            )
            for single_state, single_action in zip(reference_frames, relative_action)
        ]

        if is_batched:
            unnormalized_values[key] = np.stack(absolute_actions, axis=0)
        else:
            unnormalized_values[key] = absolute_actions[0]

    return unnormalized_values
def get_state_dim(self, embodiment_tag: str, include_sincos_expansion: bool = False) -> int:
    """
    Get total state dimension after processing.

    Sums the cached "dim" entry of every configured state joint group. Groups
    listed in ``sin_cos_embedding_keys`` count twice when both
    ``include_sincos_expansion`` and ``self.apply_sincos_state_encoding`` are
    set.

    Args:
        embodiment_tag: Embodiment identifier
        include_sincos_expansion: If True, accounts for sin/cos encoding doubling dimensions

    Returns:
        Total state dimension across all joint groups
    """
    state_config = self.modality_configs[embodiment_tag]["state"]

    # Get sin/cos embedding keys if enabled. The config field may be None
    # (its default), which must be treated as "no keys" rather than passed
    # to set().
    sin_cos_keys = set()
    if self.apply_sincos_state_encoding:
        sin_cos_keys = set(getattr(state_config, "sin_cos_embedding_keys", None) or ())

    total_dim = 0
    for joint_group in state_config.modality_keys:
        base_dim = self.norm_params[embodiment_tag]["state"][joint_group]["dim"].item()

        # Sin/cos encoding doubles the dimension
        if include_sincos_expansion and joint_group in sin_cos_keys:
            total_dim += base_dim * 2
        else:
            total_dim += base_dim

    return total_dim
+ + Args: + embodiment_tag: Embodiment identifier + + Returns: + Total action dimension across all joint groups + """ + total_dim = 0 + for joint_group in self.modality_configs[embodiment_tag]["action"].modality_keys: + total_dim += self.norm_params[embodiment_tag]["action"][joint_group]["dim"].item() + return total_dim + + def _convert_to_relative_action( + self, + action: np.ndarray, + reference_state: np.ndarray, + action_type: ActionType, + action_format: ActionFormat, + ) -> np.ndarray: + """Convert absolute action to relative action using reference state.""" + assert action.ndim == 2, f"Expected action shape (T, D), got {action.shape}" + assert reference_state.ndim == 1, f"Expected state shape (D,), got {reference_state.shape}" + + if action_type == ActionType.EEF: + action_chunking = EndEffectorActionChunk.from_array(action, action_format) + reference_frame = EndEffectorPose.from_action_format(reference_state, action_format) + + elif action_type == ActionType.NON_EEF: + action_chunking = JointActionChunk([JointPose(m) for m in action]) + reference_frame = JointPose(reference_state) + + else: + raise ValueError(f"Unknown ActionType: {action_type}") + + relative_action_chunking = action_chunking.relative_chunking(reference_frame=reference_frame) + return relative_action_chunking.to(action_format) + + def _convert_to_absolute_action( + self, + action: np.ndarray, + reference_state: np.ndarray, + action_type: ActionType, + action_format: ActionFormat, + ) -> np.ndarray: + """Convert relative action to absolute action using reference state.""" + assert action.ndim == 2, f"Expected action shape (T, D), got {action.shape}" + assert reference_state.ndim == 1, f"Expected state shape (D,), got {reference_state.shape}" + assert reference_state.shape[0] == action.shape[1], ( + f"State dim {reference_state.shape[0]} != action dim {action.shape[1]}" + ) + + if action_type == ActionType.EEF: + rel_action = EndEffectorActionChunk.from_array(action, action_format) + 
reference_frame = EndEffectorPose.from_action_format(reference_state, action_format) + + elif action_type == ActionType.NON_EEF: + rel_action = JointActionChunk([JointPose(pose) for pose in action]) + reference_frame = JointPose(reference_state) + + else: + raise ValueError(f"Unknown ActionType: {action_type}") + + abs_action = rel_action.to_absolute_chunking(reference_frame=reference_frame) + return abs_action.to(action_format) + + def __str__(self) -> str: + return ( + "StateActionProcessor(" + f"modality_configs={self.modality_configs}, " + f"statistics={self.statistics}, " + f"use_percentiles={self.use_percentiles}, " + f"clip_outliers={self.clip_outliers}, " + f"apply_sincos_state_encoding={self.apply_sincos_state_encoding}, " + f"use_relative_action={self.use_relative_action})" + ) diff --git a/vllm_omni/diffusion/models/gr00t/dataio/types.py b/vllm_omni/diffusion/models/gr00t/dataio/types.py new file mode 100644 index 00000000000..09e32e0a7e5 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/dataio/types.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class MessageType(Enum):
    """String-valued tags for message kinds (episode start/end/step, image, text).

    NOTE(review): no consumer of this enum is visible in this part of the file.
    """

    START_OF_EPISODE = "start_of_episode"
    END_OF_EPISODE = "end_of_episode"
    EPISODE_STEP = "episode_step"
    IMAGE = "image"
    TEXT = "text"


class ActionRepresentation(Enum):
    """How an action is expressed; stored as ``ActionConfig.rep``."""

    RELATIVE = "relative"
    DELTA = "delta"
    ABSOLUTE = "absolute"


class ActionType(Enum):
    """Whether an action is an end-effector (``EEF``) or a non-end-effector action.

    Stored as ``ActionConfig.type``.
    """

    EEF = "eef"
    NON_EEF = "non_eef"


class ActionFormat(Enum):
    """Layout tag for an action vector; stored as ``ActionConfig.format``."""

    DEFAULT = "default"
    XYZ_ROT6D = "xyz+rot6d"
    XYZ_ROTVEC = "xyz+rotvec"
@dataclass
class ModalityConfig:
    """Configuration for a modality defining how data should be sampled and loaded.

    This class specifies which indices to sample relative to a base index and which
    keys to load for a particular modality (e.g., video, state, action).

    Raises:
        ValueError: If ``delta_indices`` is not a list, ``modality_keys`` is not a
            non-empty list, or ``action_configs`` is given with a length different
            from ``modality_keys``.
        KeyError: If a dict-form action config is missing ``rep``/``type``/``format``
            or names an enum member that does not exist.
    """

    delta_indices: list[int]
    """Delta indices to sample relative to the current index. The returned data will
    correspond to the original data at a sampled base index + delta indices."""
    modality_keys: list[str]
    """The keys to load for the modality in the dataset."""
    sin_cos_embedding_keys: list[str] | None = None
    """Optional list of keys to apply sin/cos encoding. If None or empty, use
    min/max normalization for all keys."""
    mean_std_embedding_keys: list[str] | None = None
    """Optional list of keys to apply mean/std normalization. If None or empty,
    use min/max normalization for all keys."""
    action_configs: list[ActionConfig] | None = None
    """Optional per-key action configs, one per entry of ``modality_keys``. Entries
    may be ``ActionConfig`` objects or plain dicts, which are parsed on init."""

    def __post_init__(self):
        """Validate fields and convert dict-form action configs to ``ActionConfig``."""
        # isinstance() already rejects None, so no separate None check is needed.
        if not isinstance(self.delta_indices, list):
            raise ValueError(f"delta_indices must be a non-None list, got {self.delta_indices!r}")
        if not isinstance(self.modality_keys, list) or not self.modality_keys:
            raise ValueError(f"modality_keys must be a non-empty list, got {self.modality_keys!r}")
        if self.action_configs is None:
            return

        # Raise rather than assert: this is input validation and must survive
        # `python -O`, like the two checks above.
        if len(self.action_configs) != len(self.modality_keys):
            raise ValueError(
                f"Number of action configs ({len(self.action_configs)}) must match "
                f"number of modality keys ({len(self.modality_keys)})"
            )

        parsed_action_configs = []
        for action_config in self.action_configs:
            if isinstance(action_config, dict):
                # Enum members are looked up by NAME (e.g. "RELATIVE"), not by
                # value (e.g. "relative").
                action_config = ActionConfig(
                    rep=ActionRepresentation[action_config["rep"]],
                    type=ActionType[action_config["type"]],
                    format=ActionFormat[action_config["format"]],
                    state_key=action_config.get("state_key"),
                )
            parsed_action_configs.append(action_config)
        self.action_configs = parsed_action_configs
def apply_sin_cos_encoding(values: np.ndarray) -> np.ndarray:
    """Encode values as their sines followed by their cosines.

    Args:
        values: Array of shape (..., D).

    Returns:
        Array of shape (..., 2*D): all sines first, then all cosines, joined
        along the last axis. For input ``[v1, v2, v3]`` the output is
        ``[sin(v1), sin(v2), sin(v3), cos(v1), cos(v2), cos(v3)]``.
    """
    # Block layout (sines, then cosines), not interleaved per feature.
    return np.concatenate((np.sin(values), np.cos(values)), axis=-1)
def normalize_values_minmax(values, params):
    """
    Normalize values using min-max normalization to [-1, 1] range.

    Args:
        values: Input values to normalize
            - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension
            - Integer input is accepted; the result is then float64
        params: Dictionary with "min" and "max" keys
            - params["min"] / params["max"]: bounds used for normalization
                * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps
                * Case 2 - 2D bounds: Shape (T, D) - different min/max per step

    Returns:
        Normalized values, same shape as the input
            - Values are linearly mapped from [min, max] to [-1, 1]; inputs outside
              [min, max] map outside [-1, 1] (no clipping is applied here)
            - For features where min and max are (numerically) equal, the result is 0
            - Floating inputs keep their dtype; any other dtype yields float64

    Examples:
        # 1D bounds - same normalization for all steps
        values: (10, 5), params["min"]: (5,), params["max"]: (5,)

        # 2D bounds - per-step normalization
        values: (8, 4), params["min"]: (8, 4), params["max"]: (8, 4)
    """
    values = np.asarray(values)
    min_vals = np.asarray(params["min"])
    max_vals = np.asarray(params["max"])

    # Allocate a floating output explicitly: zeros_like() would inherit an integer
    # dtype from integer input and silently truncate the normalized values.
    out_dtype = values.dtype if np.issubdtype(values.dtype, np.floating) else np.float64
    normalized = np.zeros(values.shape, dtype=out_dtype)

    # Features with a degenerate range are skipped and therefore stay at 0.
    mask = ~np.isclose(max_vals, min_vals)

    normalized[..., mask] = (values[..., mask] - min_vals[..., mask]) / (max_vals[..., mask] - min_vals[..., mask])
    normalized[..., mask] = 2 * normalized[..., mask] - 1

    return normalized
def unnormalize_values_meanstd(normalized_values, params):
    """
    Undo z-score normalization: map ``x`` back to ``x * std + mean``.

    Args:
        normalized_values: Z-scored values
            - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension
        params: Dictionary with "mean" and "std" keys
            - Each is either 1D, shape (D,), shared by all steps,
              or 2D, shape (T, D), with one row per step

    Returns:
        Array with the same shape and dtype as ``normalized_values``
            - Features with non-zero std are transformed as ``x * std + mean``
            - Features whose std is exactly 0 are passed through unchanged

    Examples:
        # 1D params - same unnormalization for all steps
        normalized_values: (10, 5), params["mean"]: (5,), params["std"]: (5,)

        # 2D params - per-step unnormalization
        normalized_values: (8, 4), params["mean"]: (8, 4), params["std"]: (8, 4)
    """
    mean_vals, std_vals = params["mean"], params["std"]

    # Start from a copy so zero-std features already hold their pass-through value.
    restored = np.array(normalized_values, copy=True, subok=True)

    # Only features with a usable std are rescaled and shifted.
    has_spread = std_vals != 0
    restored[..., has_spread] = (
        normalized_values[..., has_spread] * std_vals[..., has_spread] + mean_vals[..., has_spread]
    )
    return restored
def parse_modality_configs(
    modality_configs: dict[str, dict[str, ModalityConfig]],
) -> dict[str, dict[str, ModalityConfig]]:
    """Return a copy of the two-level mapping with dict entries turned into ``ModalityConfig``.

    Args:
        modality_configs: Mapping of embodiment tag -> modality name -> config.
            A config given as a plain dict is expanded into ``ModalityConfig(**config)``;
            any other value is kept as-is.

    Returns:
        A new nested dict with the same keys; the input mapping is not modified.
    """
    return {
        embodiment_tag: {
            modality: ModalityConfig(**config) if isinstance(config, dict) else config
            for modality, config in per_modality.items()
        }
        for embodiment_tag, per_modality in modality_configs.items()
    }
00000000000..7aecb405ff9 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7 +from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor + +__all__ = ["Gr00tN1d7", "Gr00tN1d7Processor"] diff --git a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py new file mode 100644 index 00000000000..4a84d6eceb5 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py @@ -0,0 +1,776 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _make_qwen3_vl_2b_config(
    *,
    backbone_embedding_dim: int,
    num_hidden_layers: int,
    attn_implementation: str,
) -> Qwen3VLConfig:
    """Build the hard-coded Qwen3-VL configuration used for the GR00T N1.7 backbone.

    Args:
        backbone_embedding_dim: Text hidden size and vision output size; must be 2048.
        num_hidden_layers: Number of language-model layers to configure.
        attn_implementation: Value stored on the config's ``_attn_implementation``.

    Returns:
        A ``Qwen3VLConfig`` with the text/vision settings below.

    Raises:
        ValueError: If ``backbone_embedding_dim`` is not 2048.
    """
    if backbone_embedding_dim != 2048:
        raise ValueError(f"GR00T N1.7 expects a 2048-dim Qwen3-VL backbone, got {backbone_embedding_dim}")

    text_settings = {
        "vocab_size": 151936,
        "hidden_size": backbone_embedding_dim,
        "intermediate_size": 6144,
        "num_hidden_layers": num_hidden_layers,
        "num_attention_heads": 16,
        "num_key_value_heads": 8,
        "head_dim": 128,
        "max_position_embeddings": 262144,
        "rope_theta": 5000000,
        "rope_scaling": {
            "mrope_interleaved": True,
            "mrope_section": [24, 20, 20],
            "rope_type": "default",
        },
        "tie_word_embeddings": True,
        "bos_token_id": 151643,
        "eos_token_id": 151645,
        "dtype": "bfloat16",
    }
    vision_settings = {
        "depth": 24,
        "hidden_size": 1024,
        "intermediate_size": 4096,
        "num_heads": 16,
        # The vision tower projects into the text hidden size.
        "out_hidden_size": backbone_embedding_dim,
        "deepstack_visual_indexes": [5, 11, 17],
        "patch_size": 16,
        "spatial_merge_size": 2,
        "temporal_patch_size": 2,
        "num_position_embeddings": 2304,
        "in_channels": 3,
    }

    qwen_config = Qwen3VLConfig(
        text_config=text_settings,
        vision_config=vision_settings,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        tie_word_embeddings=False,
    )
    # Set after construction, exactly as a plain attribute assignment.
    qwen_config._attn_implementation = attn_implementation
    return qwen_config
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool):
    """Record the tuning flags and freeze the parts of the action head that are not tuned.

    All parameters are first marked trainable, then each disabled group is frozen.

    Args:
        tune_projector: If False, freeze the state encoder, action encoder, action
            decoder and (when ``config.add_pos_embed`` is set) the position embedding.
        tune_diffusion_model: If False, freeze ``self.model``.
        tune_vlln: If False, freeze ``self.vlln`` and ``self.vl_self_attention``.
    """
    self.tune_projector = tune_projector
    self.tune_diffusion_model = tune_diffusion_model
    self.tune_vlln = tune_vlln

    # Reset everything to trainable so repeated calls do not accumulate freezes.
    for p in self.parameters():
        p.requires_grad = True

    if not tune_projector:
        self.state_encoder.requires_grad_(False)
        self.action_encoder.requires_grad_(False)
        self.action_decoder.requires_grad_(False)
        # position_embedding only exists when add_pos_embed is enabled.
        if self.config.add_pos_embed:
            self.position_embedding.requires_grad_(False)
    if not tune_diffusion_model:
        self.model.requires_grad_(False)
    if not tune_vlln:
        self.vlln.requires_grad_(False)
        self.vl_self_attention.requires_grad_(False)

    # Lazy %-style arguments: the message is only formatted if DEBUG is enabled.
    logger.debug("Tune action head projector: %s", self.tune_projector)
    logger.debug("Tune action head diffusion model: %s", self.tune_diffusion_model)
    logger.debug("Tune action head vlln: %s", self.tune_vlln)

    # When every group is disabled, list whatever is still trainable to help
    # spot parameters that none of the three flags cover.
    if not tune_projector and not tune_diffusion_model and not tune_vlln:
        for name, p in self.named_parameters():
            if p.requires_grad:
                logger.debug("Action head trainable parameter: %s", name)

    # Warn if nothing at all is left to train.
    if not any(p.requires_grad for p in self.parameters()):
        logger.warning("No action head trainable parameters found.")
+ + Args: + backbone_output: Output from the backbone model containing: + - backbone_features: [B, seq_len, backbone_embedding_dim] + - backbone_attention_mask: [B, seq_len] + action_input: Input containing: + - state: [B, state_dim] + - action: [B, action_horizon, action_dim] (during training) + - embodiment_id: [B] (embodiment IDs) + - action_mask: [B, action_horizon, action_dim] + + Returns: + BatchFeature containing: + - loss: action prediction loss + """ + # Set frozen modules to eval + self.set_frozen_modules_to_eval_mode() + + backbone_output = self.process_backbone_output(backbone_output) + + # Get vision and language embeddings. + vl_embeds = backbone_output.backbone_features + device = vl_embeds.device + + # Get embodiment ID. + embodiment_id = action_input.embodiment_id + + # Handle state history + assert action_input.state.shape[1] == self.config.state_history_length + action_input.state = action_input.state.view(action_input.state.shape[0], 1, -1) + + # Embed state. + state_features = self.state_encoder(action_input.state, embodiment_id) + + # Dropout state features (training only): zero out dropped states. + if self.training and self.state_dropout_prob > 0: + do_dropout = torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob + do_dropout = do_dropout[:, None, None].to(dtype=state_features.dtype) + state_features = state_features * (1 - do_dropout) + + # Embed noised action trajectory. 
+ actions = action_input.action + noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype) + t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype) + t = t[:, None, None] # shape (B,1,1) for broadcast + + noisy_trajectory = (1 - t) * noise + t * actions + velocity = actions - noise + + # Convert (continuous) t -> discrete if needed + t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long() + action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id) + + # Maybe add position embedding. + if self.config.add_pos_embed: + pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device) + pos_embs = self.position_embedding(pos_ids).unsqueeze(0) + action_features = action_features + pos_embs + + # Join vision, language, state and action embedding along sequence dimension. + sa_embs = torch.cat((state_features, action_features), dim=1) + vl_attn_mask = backbone_output.backbone_attention_mask + + if self.config.use_alternate_vl_dit: + image_mask = backbone_output.image_mask + backbone_attention_mask = backbone_output.backbone_attention_mask + model_output, _ = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embeds, + encoder_attention_mask=vl_attn_mask, + timestep=t_discretized, + return_all_hidden_states=True, + image_mask=image_mask, + backbone_attention_mask=backbone_attention_mask, + ) + else: + model_output, _ = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embeds, + encoder_attention_mask=vl_attn_mask, + timestep=t_discretized, + return_all_hidden_states=True, + ) + + pred = self.action_decoder(model_output, embodiment_id) + pred_actions = pred[:, -actions.shape[1] :] + + # Slice out only the action portion of pred and target. 
+ action_mask = action_input.action_mask + action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask + loss = action_loss.sum() / (action_mask.sum() + 1e-6) + + return { + "loss": loss, + "action_loss": action_loss, + "action_mask": action_mask, + "backbone_features": vl_embeds, + "state_features": state_features, + } + + def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature: + """ + Encode features for the action head. + + Args: + backbone_output: Output from the backbone model containing: + - backbone_features: [B, seq_len, backbone_embedding_dim] + - backbone_attention_mask: [B, seq_len] + action_input: Input containing: + - state: [B, state_history_length, max_state_dim] + - embodiment_id: [B] (embodiment IDs) + + Returns: + BatchFeature containing: + - backbone_features: [B, seq_len, backbone_embedding_dim] + - state_features: [B, 1, input_embedding_dim] + """ + backbone_output = self.process_backbone_output(backbone_output) + + # Get vision and language embeddings. + vl_embeds = backbone_output.backbone_features + embodiment_id = action_input.embodiment_id + + # Handle state history: if we have fewer timesteps than expected, repeat to fill + state = action_input.state + current_T = state.shape[1] + assert current_T == self.config.state_history_length, "current_T != state_history_length" + # Reshape state from [B, state_history_length, max_state_dim] to [B, 1, state_history_length * max_state_dim] + state = state.view(state.shape[0], 1, -1) + + # Embed state. 
@torch.no_grad()
def get_action_with_features(
    self,
    backbone_features: torch.Tensor,
    state_features: torch.Tensor,
    embodiment_id: torch.Tensor,
    backbone_output: BatchFeature,
    action_input: BatchFeature,
    options: dict[str, Any] | None = None,
) -> BatchFeature:
    """
    Generate actions using the flow matching diffusion process.

    Starts from Gaussian noise and applies ``num_inference_timesteps`` Euler steps
    of the predicted velocity. If ``action_input`` contains ``"action"``, the head
    of the chunk is seeded from that previous action and its update strength is
    reduced (RTC inpainting).

    Args:
        backbone_features: [B, seq_len, backbone_embedding_dim]
        state_features: encoded state, concatenated in front of the action
            features along dim 1
        embodiment_id: [B] (embodiment IDs)
        backbone_output: Output from the backbone model; ``image_mask`` and
            ``backbone_attention_mask`` are read when ``use_alternate_vl_dit`` is set
        action_input: If it contains ``"action"``, that tensor is used as the
            previous action chunk for RTC
        options: Required when ``action_input`` contains ``"action"``. Keys:
            - ``action_horizon``: horizon of the previous action before padding
            - ``rtc_overlap_steps``: number of leading steps copied from the
              previous action
            - ``rtc_frozen_steps``: number of leading steps that are never updated
            - ``rtc_ramp_rate``: rate of the exponential ramp applied to the
              remaining overlapped steps

    Returns:
        BatchFeature with ``action_pred`` [B, action_horizon, action_dim],
        ``backbone_features`` and ``state_features``.

    Raises:
        AssertionError: If RTC is requested but ``options`` or a required key is missing.
    """
    vl_embeds = backbone_features

    # Set initial actions as the sampled noise.
    batch_size = vl_embeds.shape[0]
    device = vl_embeds.device
    actions = torch.randn(
        size=(batch_size, self.config.action_horizon, self.action_dim),
        dtype=vl_embeds.dtype,
        device=device,
    )

    dt = 1.0 / self.num_inference_timesteps
    # Per-element multiplier on the velocity update; 1.0 means a normal Euler step.
    vel_strength = torch.ones_like(actions)

    if "action" in action_input:
        # A previous action in the input means we want to use RTC.
        assert options is not None, "options must be provided when action_input contains 'action'"
        assert "action_horizon" in options, "action_horizon is not in options"
        assert "rtc_overlap_steps" in options, "rtc_overlap_steps is not in options"
        assert "rtc_frozen_steps" in options, "rtc_frozen_steps is not in options"
        assert "rtc_ramp_rate" in options, "rtc_ramp_rate is not in options"

        action_horizon_before_padding = options["action_horizon"]
        overlap_steps = options["rtc_overlap_steps"]
        frozen_steps = options["rtc_frozen_steps"]

        # Use the tail of the previous action instead of pure noise to do inpainting.
        actions[:, :overlap_steps, :] = action_input["action"][
            :,
            action_horizon_before_padding - overlap_steps : action_horizon_before_padding,
            :,
        ]
        # Frozen steps receive no velocity update at all.
        vel_strength[:, :frozen_steps, :] = 0.0

        # NOTE: use an exponential ramp strength for the remaining unfrozen overlap steps.
        intermediate_steps = overlap_steps - frozen_steps
        # Two extra samples so the 0.0 and 1.0 endpoints can be dropped below.
        ramp_t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
        ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * ramp_t)
        ramp = ramp / ramp[-1].clamp_min(1e-8)  # normalize to [0,1]
        ramp = ramp[1:-1]  # keep only the interior of the ramp
        # Broadcast the ramp over batch and action_dim.
        vel_strength[:, frozen_steps:overlap_steps, :] = ramp[None, :, None].to(device)

    # Run denoising steps.
    for step in range(self.num_inference_timesteps):
        # Map the integration step onto the discrete timestep buckets.
        t_discretized = step * self.num_timestep_buckets // self.num_inference_timesteps

        # Embed noised action trajectory.
        timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
        action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
        # Add position embedding.
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs

        # Join state and action embeddings along the sequence dimension.
        sa_embs = torch.cat((state_features, action_features), dim=1)

        # Run model forward.
        if self.config.use_alternate_vl_dit:
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embeds,
                timestep=timesteps_tensor,
                image_mask=backbone_output.image_mask,
                backbone_attention_mask=backbone_output.backbone_attention_mask,
            )
        else:
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embeds,
                timestep=timesteps_tensor,
            )
        pred = self.action_decoder(model_output, embodiment_id)

        # Only the trailing action positions carry the velocity prediction.
        pred_velocity = pred[:, -self.action_horizon :]

        # Update actions using euler integration.
        actions = actions + dt * pred_velocity * vel_strength

    return BatchFeature(
        data={
            "action_pred": actions,
            "backbone_features": vl_embeds,
            "state_features": state_features,
        }
    )
+ + Args: + backbone_output: Output from the backbone model containing: + - backbone_features: [B, seq_len, backbone_embedding_dim] + - backbone_attention_mask: [B, seq_len] + action_input: Input containing: + - state: [B, state_dim] + - embodiment_id: [B] (embodiment IDs) + + Returns: + BatchFeature containing: + - action_pred: [B, action_horizon, action_dim] predicted actions + """ + features = self._encode_features(backbone_output, action_input) + return self.get_action_with_features( + backbone_features=features.backbone_features, + state_features=features.state_features, + embodiment_id=action_input.embodiment_id, + backbone_output=backbone_output, + action_input=action_input, + options=options, + ) + + @property + def device(self): + return next(iter(self.parameters())).device + + @property + def dtype(self): + return next(iter(self.parameters())).dtype + + def prepare_input(self, batch: dict) -> BatchFeature: + """Prepare input batch for the action head.""" + return BatchFeature(data=batch) + + +class _Qwen3VLBackbone(nn.Module): + """GR00T adapter around the shared Qwen3-VL implementation.""" + + def __init__( + self, + model_name: str, + tune_llm: bool, + tune_visual: bool, + select_layer: int, + reproject_vision: bool, + use_flash_attention: bool, + backbone_embedding_dim: int, + load_bf16: bool, + tune_top_llm_layers: int, + trainable_params_fp32: bool, + transformers_loading_kwargs: dict[str, Any] | None = None, + ): + super().__init__() + del model_name, reproject_vision, transformers_loading_kwargs + + if use_flash_attention: + try: + import flash_attn # noqa: F401 + + attn_implementation = "flash_attention_2" + except ImportError: + logger.warning( + "flash_attn is not installed. Falling back to sdpa attention. 
" + "Install flash-attn for better performance: pip install flash-attn" + ) + attn_implementation = "sdpa" + else: + attn_implementation = "sdpa" + + num_hidden_layers = select_layer if select_layer >= 0 else 28 + backbone_config = _make_qwen3_vl_2b_config( + backbone_embedding_dim=backbone_embedding_dim, + num_hidden_layers=num_hidden_layers, + attn_implementation=attn_implementation, + ) + self.model = Qwen3VLForConditionalGeneration(backbone_config).eval() + if load_bf16: + self.model.to(dtype=torch.bfloat16) + + target_layers = select_layer if select_layer >= 0 else len(self.model.model.language_model.layers) + while len(self.model.model.language_model.layers) > target_layers: + self.model.model.language_model.layers.pop(-1) + + self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers) + if load_bf16 and trainable_params_fp32: + for name, param in self.named_parameters(): + if param.requires_grad: + param.data = param.data.to(torch.float32) + logger.debug("Casting trainable parameter %s to fp32", name) + + def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int) -> None: + self.tune_llm = tune_llm + self.tune_visual = tune_visual + for param in self.parameters(): + param.requires_grad = True + if not tune_llm: + self.model.model.language_model.requires_grad_(False) + if not tune_visual: + self.model.model.visual.requires_grad_(False) + + if tune_top_llm_layers > 0: + for layer in self.model.model.language_model.layers[-tune_top_llm_layers:]: + for param in layer.parameters(): + param.requires_grad = True + + logger.debug("Tune backbone llm: %s", self.tune_llm) + logger.debug("Tune backbone visual: %s", self.tune_visual) + for name, param in self.named_parameters(): + if param.requires_grad: + logger.debug("Backbone trainable parameter: %s", name) + if not any(param.requires_grad for param in self.parameters()): + logger.warning("No backbone trainable parameters found.") + + def 
set_frozen_modules_to_eval_mode(self) -> None: + if self.training: + if self.model.model.language_model and not self.tune_llm: + self.model.model.language_model.eval() + if self.model.model.visual and not self.tune_visual: + self.model.model.visual.eval() + + def prepare_input(self, batch: dict) -> BatchFeature: + return BatchFeature(data=batch) + + def forward(self, vl_input: BatchFeature) -> BatchFeature: + self.set_frozen_modules_to_eval_mode() + keys_to_use = [ + "input_ids", + "attention_mask", + "pixel_values", + "image_grid_thw", + "mm_token_type_ids", + ] + vl_input = {key: vl_input[key] for key in keys_to_use if key in vl_input} + # GR00T was trained against the **pre-final-RMSNorm** output of the last + # decoder layer (this was `outer.hidden_states[-1]` under HF 4.57.x semantics, + # where the @check_model_inputs tie-logic skipped overwriting because + # Qwen3VLCausalLMOutputWithPast has no `last_hidden_state` field). + # HF >=5.x changed the output_hidden_states tuple plumbing so that + # `hidden_states[-1]` is post-norm, which is a different tensor and breaks + # the trained action head. Capture the pre-norm tensor via a hook on the + # final RMSNorm regardless of transformers version. 
class Gr00tN1d7(PreTrainedModel):
    """Gr00tN1d7: VLA model with Cosmos-Reason2-2B (Qwen3-VL) backbone.

    Built from a single ``Gr00tN1d7Config``, the model owns three parts:

    * ``backbone`` - the vision-language backbone class picked by
      ``get_backbone_cls``;
    * ``action_head`` - a ``Gr00tN1d7ActionHead``;
    * ``collator`` - a ``Gr00tN1d7DataCollator`` that converts raw
      ``vlm_content`` entries into backbone inputs inside ``prepare_input``.
    """

    config_class = Gr00tN1d7Config
    supports_gradient_checkpointing = True
    _tp_plan = {}

    @property
    def all_tied_weights_keys(self) -> dict[str, Any]:
        """Return an empty mapping: this model declares no tied weights."""
        return {}

    def __init__(
        self,
        config: Gr00tN1d7Config,
        transformers_loading_kwargs: dict | None = None,
    ):
        """
        Initialize Gr00tN1d7 model.

        Args:
            config: Model configuration.
            transformers_loading_kwargs: Loading parameters forwarded unchanged
                to the backbone and to the data collator. When omitted,
                ``{"trust_remote_code": True}`` is used.

        Note: During training, transformers parameters are passed from training config.
              During inference (e.g., from_pretrained), defaults are used.
        """
        super().__init__(config)
        self.config = config

        # A dict literal as a default argument is created once and shared by
        # every instance; build a fresh one per call instead so one model can
        # never observe another model's (or a callee's) modifications.
        if transformers_loading_kwargs is None:
            transformers_loading_kwargs = {"trust_remote_code": True}

        backbone_cls = get_backbone_cls(config)
        self.backbone = backbone_cls(
            model_name=config.model_name,
            tune_llm=config.tune_llm,
            tune_visual=config.tune_visual,
            select_layer=config.select_layer,
            reproject_vision=config.reproject_vision,
            use_flash_attention=config.use_flash_attention,
            backbone_embedding_dim=config.backbone_embedding_dim,
            load_bf16=config.load_bf16,
            tune_top_llm_layers=config.tune_top_llm_layers,
            trainable_params_fp32=config.backbone_trainable_params_fp32,
            transformers_loading_kwargs=transformers_loading_kwargs,
        )

        # Initialize action head
        self.action_head = Gr00tN1d7ActionHead(config)
        # Imported here rather than at module level, as in the original code.
        from .processing_gr00t_n1d7 import Gr00tN1d7DataCollator

        self.collator = Gr00tN1d7DataCollator(
            model_name=config.model_name,
            model_type=config.backbone_model_type,
            transformers_loading_kwargs=transformers_loading_kwargs,
        )

    def prepare_input(self, inputs: dict) -> tuple[BatchFeature, BatchFeature]:
        """Prepare inputs for backbone and action head.

        If ``inputs`` carries a ``"vlm_content"`` entry, it is run through the
        collator and replaced by the collated tensors. The caller's mapping is
        left untouched; the replacement happens on a copy.

        Args:
            inputs: Raw model inputs.

        Returns:
            ``(backbone_inputs, action_inputs)``, both moved to this model's
            device and dtype.
        """

        # NOTE -- currently the eval code doesn't use collator, so we need to add it here
        # this should ideally be fixed upstream
        if "vlm_content" in inputs:
            # Fix for n_envs > 1: Process all environments' VLM content, not just the first
            vlm_content_list = inputs["vlm_content"]
            # Ensure vlm_content_list is always a list for consistent processing
            if not isinstance(vlm_content_list, list):
                vlm_content_list = [vlm_content_list]

            # Process all VLM contents through the collator
            prep = self.collator([{"vlm_content": vlm} for vlm in vlm_content_list])["inputs"]
            # Rebuild instead of pop()/update() in place: the caller keeps its
            # original dict (including "vlm_content") and can reuse it.
            inputs = {key: value for key, value in inputs.items() if key != "vlm_content"}
            inputs.update(prep)

        backbone_inputs = self.backbone.prepare_input(inputs)
        action_inputs = self.action_head.prepare_input(inputs)

        backbone_inputs = backbone_inputs.to(device=self.device, dtype=self.dtype)
        action_inputs = action_inputs.to(device=self.device, dtype=self.dtype)

        return backbone_inputs, action_inputs

    def forward(self, inputs: dict) -> BatchFeature:
        """
        Forward pass through the complete model.

        Args:
            inputs: Dictionary containing:
                - Action inputs (state, action, embodiment_id, etc.)

        Returns:
            The action head's output for the backbone features of ``inputs``.
        """
        # Prepare inputs for backbone and action head
        backbone_inputs, action_inputs = self.prepare_input(inputs)
        backbone_outputs = self.backbone(backbone_inputs)
        action_outputs = self.action_head(backbone_outputs, action_inputs)

        return action_outputs

    def get_action(self, inputs: dict, options: dict[str, Any] | None = None) -> BatchFeature:
        """
        Generate actions using the complete model.

        Args:
            inputs: Raw model inputs, prepared via ``prepare_input``.
            options: Optional settings handed unchanged to the action head's
                ``get_action``.

        Returns:
            The action head's ``get_action`` result.
        """
        # Prepare inputs for backbone and action head
        backbone_inputs, action_inputs = self.prepare_input(inputs)

        # Forward through backbone
        backbone_outputs = self.backbone(backbone_inputs)
        action_outputs = self.action_head.get_action(backbone_outputs, action_inputs, options)

        return action_outputs

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        """Dtype of the first model parameter."""
        return next(iter(self.parameters())).dtype
def apply_with_replay(transform, images, masks=None, replay=None):
    """
    Apply albumentations transforms to multiple images with replay functionality.
    When masks are provided, mask-based transforms run per-frame before the main transform.

    Args:
        transform: Albumentations ReplayCompose or Compose transform. An object
            exposing a ``replay`` attribute is treated as replay-capable; an
            optional ``mask_transforms`` attribute lists per-frame mask transforms.
        images: List of PIL Images to transform
        masks: Optional list of masks aligned with images (H, W)
        replay: Optional replay data for consistent transforms. If None, creates new replay.

    Returns:
        tuple: (transformed_tensors_list, replay_data)
            - transformed_tensors_list: List of transformed torch tensors (C, H, W) as uint8
            - replay_data: Replay data for consistent transforms across images (None for regular Compose)

    Raises:
        ValueError: If ``masks`` and ``images`` differ in length, or a transform
            returns an image that is neither uint8 nor float32.
    """
    transformed_tensors = []
    current_replay = replay

    # Check if transform supports replay (ReplayCompose)
    has_replay = hasattr(transform, "replay")

    # Get mask-based transforms (applied per-frame, not replayed)
    mask_transforms = getattr(transform, "mask_transforms", None)

    if masks is not None and len(masks) != len(images):
        raise ValueError(f"Number of masks ({len(masks)}) must match number of images ({len(images)})")

    for idx, img in enumerate(images):
        img_array = np.array(img)
        mask_array = None if masks is None else np.array(masks[idx])
        if mask_array is not None and mask_array.dtype == np.bool_:
            # Boolean masks are converted so mask transforms always see integer labels.
            mask_array = mask_array.astype(np.uint8)

        # Apply mask-based transforms FIRST (per-frame, using current frame's mask)
        if mask_transforms and mask_array is not None:
            for mask_tf in mask_transforms:
                img_array = mask_tf(image=img_array, mask=mask_array)["image"]

        if not has_replay:
            # Regular Compose transform - no replay functionality
            augmented_image = transform(image=img_array)
        elif current_replay is None:
            # First image - create replay data
            augmented_image = transform(image=img_array)
            current_replay = augmented_image["replay"]
        else:
            # Subsequent images - use replay for consistent transforms
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                augmented_image = transform.replay(image=img_array, saved_augmentations=current_replay)

        img_array = augmented_image["image"]
        # Convert to uint8 if needed (albumentations may return float32 in [0,1])
        if img_array.dtype == np.float32:
            # Clip before the cast: a float even slightly outside [0, 1] would
            # otherwise wrap around in uint8 (e.g. 1.5 -> 382 -> 126) and show
            # up as a wrong-coloured pixel instead of saturating.
            img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8)
        elif img_array.dtype != np.uint8:
            raise ValueError(f"Unexpected data type: {img_array.dtype}")

        # Convert to torch tensor (C, H, W) as uint8
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)
        transformed_tensors.append(img_tensor)

    return transformed_tensors, current_replay
class FractionalRandomCrop(A.DualTransform):
    """Take a randomly placed crop whose sides are a fixed fraction of the input's sides.

    Height and width are each multiplied by ``crop_fraction`` and truncated to
    an integer (never below one pixel). The aspect ratio is therefore kept up
    to rounding, and the cropped *area* is roughly ``crop_fraction ** 2`` of
    the original area.

    Args:
        crop_fraction: Per-side scale factor in the interval (0.0, 1.0].
        p: probability of applying the transform. Default: 1.0

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        crop_fraction: float = 0.9,
        p: float = 1.0,
        always_apply: bool | None = None,
    ):
        super().__init__(p=p, always_apply=always_apply)
        if not 0.0 < crop_fraction <= 1.0:
            raise ValueError("crop_fraction must be between 0.0 and 1.0")
        self.crop_fraction = crop_fraction

    def apply(self, img: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray:
        """Slice ``img`` to the ``(x_min, y_min, x_max, y_max)`` window."""
        left, top, right, bottom = crop_coords
        return img[top:bottom, left:right]

    def apply_to_bboxes(self, bboxes: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray:
        """Delegate bounding-box cropping to albumentations' coordinate-based helper."""
        crop_bboxes = A.augmentations.crops.functional.crop_bboxes_by_coords
        return crop_bboxes(bboxes, crop_coords, params["shape"])

    def apply_to_keypoints(self, keypoints: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray:
        """Delegate keypoint cropping to albumentations' coordinate-based helper."""
        crop_keypoints = A.augmentations.crops.functional.crop_keypoints_by_coords
        return crop_keypoints(keypoints, crop_coords)

    def get_params_dependent_on_data(self, params, data) -> dict[str, tuple[int, int, int, int]]:
        """Pick the crop window for an image of shape ``params["shape"]``."""
        height, width = params["shape"][:2]

        # Linear per-side scaling, with a floor of one pixel.
        crop_height = max(1, int(height * self.crop_fraction))
        crop_width = max(1, int(width * self.crop_fraction))

        # Room left over for moving the window along each axis.
        slack_y = height - crop_height
        slack_x = width - crop_width

        # Row offset is drawn before the column offset; the generator is only
        # consulted for an axis that actually has slack.
        top = np.random.randint(0, slack_y + 1) if slack_y > 0 else 0
        left = np.random.randint(0, slack_x + 1) if slack_x > 0 else 0

        return {"crop_coords": (left, top, left + crop_width, top + crop_height)}

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """Constructor arguments albumentations should record for this transform."""
        return ("crop_fraction",)
class LetterBoxPad(A.DualTransform):
    """Letterbox an image: pad the shorter side with black bars until it is square.

    Albumentations counterpart of the torchvision-based LetterBoxTransform.
    After padding, every image has equal height and width regardless of its
    original aspect ratio. When the missing amount is odd, the extra pixel
    goes to the bottom / right.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(self, p: float = 1.0, always_apply: bool | None = None):
        super().__init__(p=p, always_apply=always_apply)

    def apply(
        self,
        img: np.ndarray,
        pad_top: int = 0,
        pad_bottom: int = 0,
        pad_left: int = 0,
        pad_right: int = 0,
        **params,
    ) -> np.ndarray:
        """Add the requested constant (zero) border; return ``img`` itself if none is needed."""
        if not (pad_top or pad_bottom or pad_left or pad_right):
            return img
        return cv2.copyMakeBorder(img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0)

    def get_params_dependent_on_data(self, params, data) -> dict[str, int]:
        """Compute per-side padding that makes an image of ``params["shape"]`` square."""
        height, width = params["shape"][:2]
        side = max(height, width)

        # For an already-square image both deficits are zero, so every pad is zero.
        missing_rows = side - height
        missing_cols = side - width
        top = missing_rows // 2
        left = missing_cols // 2
        return {
            "pad_top": top,
            "pad_bottom": missing_rows - top,
            "pad_left": left,
            "pad_right": missing_cols - left,
        }

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """This transform has no constructor arguments to record."""
        return ()
+ + Args: + image_target_size: Target size for resizing (list of [height, width]) + image_crop_size: Size for cropping (list of [height, width]) + random_rotation_angle: Maximum rotation angle in degrees (0 for no rotation) + color_jitter_params: Dictionary with color jitter parameters (brightness, contrast, saturation, hue) + shortest_image_edge: Shortest edge size for resizing + crop_fraction: Fraction of image to crop + extra_augmentation_config: Optional dict for additional augmentations. Supported keys: + - "background_noise_transforms": list of dicts, each with: + - "target_mask_values": list of int (e.g., [0]) + - "p": float (probability of applying transform) + - "masked_region_transforms": list of dicts, each with: + - "target_mask_values": list of int (e.g., [4] or [5]) + - "p": float (probability of applying transform) + - "alpha_range": [min, max] for random_tint mode intensity + + Returns: + tuple: (train_transform, eval_transform) - raw albumentations transforms + """ + + if crop_fraction is None: + fraction_to_use = image_crop_size[0] / image_target_size[0] + else: + fraction_to_use = crop_fraction + + if shortest_image_edge is None: + max_size = image_target_size[0] + else: + max_size = shortest_image_edge + + extra_augmentation_config = extra_augmentation_config or {} + + # Training transforms (using ReplayCompose for consistent augmentation across views) + # Use SmallestMaxSize to preserve aspect ratios, with INTER_AREA for antialiasing + train_transform_list = [ + LetterBoxPad(), + A.SmallestMaxSize(max_size=max_size, interpolation=cv2.INTER_AREA), + FractionalRandomCrop(crop_fraction=fraction_to_use), + A.SmallestMaxSize(max_size=max_size, interpolation=cv2.INTER_AREA), + ] + + if random_rotation_angle is not None and random_rotation_angle != 0: + train_transform_list.append(A.Rotate(limit=random_rotation_angle, p=1.0)) + + if color_jitter_params is not None: + # Map torchvision ColorJitter parameters to albumentations ColorJitter + # Note: 
class LetterBoxTransform:
    """Custom transform to pad non-square images to square by adding black bars.

    Works with any tensor shape where the last 3 dimensions are (C, H, W).
    Leading dimensions (batch, time, views, etc.) are preserved.
    """

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """
        Pad image to square dimensions by adding black bars to the smaller dimension.

        Args:
            img: Image tensor of shape (..., C, H, W) where ... can be any leading dimensions
                 Examples: (C, H, W), (B, C, H, W), (B, T*V, C, H, W)

        Returns:
            Padded image tensor of shape (..., C, max(H,W), max(H,W)). A square
            input is returned as-is (same object).

        Raises:
            ValueError: If ``img`` has fewer than 3 dimensions.
        """
        # Unpacking enforces the documented (..., C, H, W) layout.
        *_, _channels, h, w = img.shape

        if h == w:
            return img

        # Calculate padding needed
        max_dim = max(h, w)
        pad_h = max_dim - h
        pad_w = max_dim - w

        # Centre the image; when the deficit is odd the extra pixel goes to
        # the bottom / right.
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        # torch.nn.functional.pad takes (left, right, top, bottom) for the last
        # two dimensions and leaves every leading dimension alone, so no
        # flatten / reshape round-trip is required for batched input.
        padded_img = torch.nn.functional.pad(
            img,
            (pad_left, pad_right, pad_top, pad_bottom),
            mode="constant",
            value=0,
        )

        # Tensor subclasses (e.g. image wrappers produced earlier in a
        # transform pipeline) may be unwrapped by torch ops; hand back the
        # same type that was passed in.
        if type(padded_img) is not type(img):
            padded_img = padded_img.as_subclass(type(img))

        return padded_img
transformations. + + Args: + image_target_size: Target size for resizing (list of [height, width]) + image_crop_size: Size for cropping (list of [height, width]) + random_rotation_angle: Maximum rotation angle in degrees (0 for no rotation) + color_jitter_params: Dictionary with color jitter parameters (brightness, contrast, saturation, hue) + + Returns: + tuple: (train_transform, eval_transform) - torchvision transforms + """ + transform_list = [ + transforms.ToImage(), + LetterBoxTransform(), + # transforms.ToDtype(torch.get_default_dtype(), scale=True), + transforms.Resize(size=image_target_size), + transforms.RandomCrop(size=image_crop_size), + transforms.Resize(size=image_target_size), + ] + if random_rotation_angle is not None and random_rotation_angle != 0: + transform_list.append(transforms.RandomRotation(degrees=[-random_rotation_angle, random_rotation_angle])) + if color_jitter_params is not None: + transform_list.append(transforms.ColorJitter(**color_jitter_params)) + train_image_transform = transforms.Compose(transform_list) + eval_image_transform = transforms.Compose( + [ + transforms.ToImage(), + # transforms.ToDtype(torch.get_default_dtype(), scale=True), + LetterBoxTransform(), + transforms.Resize(size=image_target_size), + transforms.CenterCrop(size=image_crop_size), + transforms.Resize(size=image_target_size), + ] + ) + return train_image_transform, eval_image_transform diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py new file mode 100755 index 00000000000..3aa25e81c52 --- 
/dev/null +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py @@ -0,0 +1,478 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import nullcontext + +import torch +import torch.nn.functional as F +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import register_to_config +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps +from torch import nn + + +def _is_spark_sm121() -> bool: + if not torch.cuda.is_available(): + return False + + major, minor = torch.cuda.get_device_capability() + return (major, minor) == (12, 1) + + +def _should_force_math_sdpa() -> bool: + override = os.environ.get("GR00T_DIT_SDPA_MODE") + if override == "math": + return True + if override == "default": + return False + + return _is_spark_sm121() + + +def _sdpa_context(): + # Spark (sm121) currently hits noisy/broken PyTorch mem-efficient SDPA kernel dispatch. + # Force the safe math backend there; on every other platform this returns a no-op context. 
+ if not _should_force_math_sdpa(): + return nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + enable_cudnn=False, + ) + + +class TimestepEncoder(nn.Module): + def __init__(self, embedding_dim, compute_dtype=torch.float32): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timesteps): + dtype = next(self.parameters()).dtype + timesteps_proj = self.time_proj(timesteps).to(dtype) + timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D) + return timesteps_emb + + +class AdaLayerNorm(nn.Module): + def __init__( + self, + embedding_dim: int, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + self.chunk_dim = chunk_dim + output_dim = embedding_dim * 2 + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + temb = self.linear(self.silu(temb)) + scale, shift = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale[:, None]) + shift[:, None] + return x + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + attention_bias: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', + # 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: str | None = 
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: one attention sub-layer followed by a feed-forward sub-layer.

    The single attention module (``attn1``) performs cross-attention when
    ``encoder_hidden_states`` is passed to ``forward`` and self-attention
    otherwise. Only ``norm1``/``attn1`` and ``norm3``/``ff`` are created; there
    is no second attention sub-layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: int | None = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single',
        # 'ada_norm_continuous', 'layer_norm_i2vgen'
        norm_type: str = "layer_norm",
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: str | None = None,
        num_positional_embeddings: int | None = None,
        ff_inner_dim: int | None = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        """Build the attention and feed-forward sub-layers.

        Only ``norm_type == "ada_norm"`` selects the timestep-conditioned
        ``AdaLayerNorm``; every other value uses a plain ``nn.LayerNorm``.
        ``attention_type`` is accepted but not used.

        Raises:
            ValueError: If ``positional_embeddings`` is set without
                ``num_positional_embeddings``.
        """
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.norm_type = norm_type

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_position_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # 1. Attention (self- or cross-, decided at call time)
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 3. Feed-forward (numbering kept so parameter names match existing checkpoints)
        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )
        if final_dropout:
            self.final_dropout = nn.Dropout(dropout)
        else:
            self.final_dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply attention and feed-forward, each with a residual connection.

        Args:
            hidden_states: Input token features.
            attention_mask: Mask used when no ``encoder_hidden_states`` is given.
            encoder_hidden_states: Optional context to cross-attend to.
            encoder_attention_mask: Mask used when ``encoder_hidden_states`` is given.
            temb: Conditioning embedding; consumed only when ``norm_type == "ada_norm"``.
        """
        # 0. Attention
        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, temb)
        else:
            norm_hidden_states = self.norm1(hidden_states)

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        with _sdpa_context():
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=(encoder_attention_mask if encoder_hidden_states is not None else attention_mask),
            )
        # Explicit None check: do not rely on the truthiness of an nn.Module.
        if self.final_dropout is not None:
            attn_output = self.final_dropout(attn_output)

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
        return hidden_states


class DiT(ModelMixin, ConfigMixin):
    """Timestep-conditioned transformer stack with an adaptive-norm output head.

    Blocks cross-attend to ``encoder_hidden_states``; when
    ``interleave_self_attention`` is set, every odd-indexed block is a pure
    self-attention block instead.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        output_dim: int = 26,
        num_layers: int = 12,
        dropout: float = 0.1,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: int | None = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        max_num_positional_embeddings: int = 512,
        compute_dtype=torch.float32,
        final_dropout: bool = True,
        positional_embeddings: str | None = "sinusoidal",
        interleave_self_attention=False,
        cross_attention_dim: int | None = None,
    ):
        """Create the timestep encoder, the transformer blocks and the output head.

        All arguments are recorded in ``self.config`` by ``register_to_config``.
        """
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        # Timestep encoder. Use the constructor argument directly: reading
        # ``self.compute_dtype`` only works through diffusers' deprecated
        # config-attribute fallback.
        self.timestep_encoder = TimestepEncoder(embedding_dim=self.inner_dim, compute_dtype=compute_dtype)

        all_blocks = []
        for idx in range(self.config.num_layers):
            use_self_attn = idx % 2 == 1 and interleave_self_attention
            # Self-attention blocks get no cross-attention projection.
            curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None

            all_blocks += [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=final_dropout,
                    cross_attention_dim=curr_cross_attention_dim,
                )
            ]
        self.transformer_blocks = nn.ModuleList(all_blocks)

        # Output blocks
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)

        # Report through logging instead of print() so library users control the output.
        import logging

        logging.getLogger(__name__).info(
            "Total number of DiT parameters: %d",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        encoder_hidden_states: torch.Tensor,  # Shape: (B, S, D)
        timestep: torch.LongTensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        return_all_hidden_states: bool = False,
    ):
        """Run all blocks and project to ``output_dim``.

        ``encoder_attention_mask`` is accepted but not forwarded to the blocks;
        every block is called with ``encoder_attention_mask=None``.

        Returns:
            The projected output, or ``(output, all_hidden_states)`` when
            ``return_all_hidden_states`` is true; the list holds the input
            followed by the output of every block.
        """
        # Encode timesteps
        temb = self.timestep_encoder(timestep)

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        all_hidden_states = [hidden_states]

        # Single pass through the transformer blocks
        for idx, block in enumerate(self.transformer_blocks):
            is_self_attn_block = idx % 2 == 1 and self.config.interleave_self_attention
            hidden_states = block(
                hidden_states,
                attention_mask=None,
                encoder_hidden_states=None if is_self_attn_block else encoder_hidden_states,
                encoder_attention_mask=None,
                temb=temb,
            )
            all_hidden_states.append(hidden_states)

        # Output processing: timestep-conditioned shift/scale, then projection.
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        if return_all_hidden_states:
            return self.proj_out_2(hidden_states), all_hidden_states
        else:
            return self.proj_out_2(hidden_states)
class AlternateVLDiT(DiT):
    """
    Alternate Vision-Language DiT that separates image and non-image tokens
    during cross-attention processing.
    """

    def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
        """Build the parent ``DiT`` and record how often cross-attention targets non-image tokens."""
        super().__init__(*args, **kwargs)
        self.attend_text_every_n_blocks = attend_text_every_n_blocks

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        encoder_hidden_states: torch.Tensor,  # Shape: (B, S, D)
        timestep: torch.LongTensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        return_all_hidden_states: bool = False,
        image_mask: torch.Tensor | None = None,
        backbone_attention_mask: torch.Tensor | None = None,
    ):
        """Run the block stack, alternating which context tokens are visible.

        Odd-indexed blocks are self-attention. Even-indexed blocks
        cross-attend to ``encoder_hidden_states``: when
        ``idx % (2 * attend_text_every_n_blocks) == 0`` only non-image tokens
        are visible, otherwise only image tokens are.

        Args:
            encoder_attention_mask: Accepted for interface compatibility; the
                per-block masks built from ``image_mask`` are used instead.
            image_mask: Required. (B, S) mask, True for image tokens.
            backbone_attention_mask: Optional mask combined (logical AND) with
                both token masks. When omitted, the token masks are used
                unchanged.

        Returns:
            The projected output, or ``(output, all_hidden_states)`` when
            ``return_all_hidden_states`` is true.
        """
        assert image_mask is not None, "Image mask is required"

        # Encode timesteps
        temb = self.timestep_encoder(timestep)

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        # Build one mask per token group. In both masks True marks a context
        # token the cross-attention block is allowed to attend to.
        if backbone_attention_mask is None:
            # Previously this path crashed with a TypeError on `mask & None`.
            image_attention_mask = image_mask
            non_image_attention_mask = ~image_mask
        else:
            image_attention_mask = image_mask & backbone_attention_mask
            non_image_attention_mask = (~image_mask) & backbone_attention_mask

        all_hidden_states = [hidden_states]
        assert self.config.interleave_self_attention, "Interleave self attention must be enabled"

        text_period = 2 * self.attend_text_every_n_blocks

        # Single pass through the transformer blocks
        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1:
                # Self-attention blocks
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            else:
                # Cross-attention blocks - alternate between non-image and image tokens
                if idx % text_period == 0:
                    curr_encoder_attention_mask = non_image_attention_mask
                else:
                    curr_encoder_attention_mask = image_attention_mask

                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=curr_encoder_attention_mask,
                    temb=temb,
                )
            all_hidden_states.append(hidden_states)

        # Output processing
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        if return_all_hidden_states:
            return self.proj_out_2(hidden_states), all_hidden_states
        else:
            return self.proj_out_2(hidden_states)


class SelfAttentionTransformer(ModelMixin, ConfigMixin):
    """Stack of self-attention-only ``BasicTransformerBlock`` layers (no timestep conditioning)."""

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        output_dim: int = 26,
        num_layers: int = 12,
        dropout: float = 0.1,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: int | None = 1000,
        upcast_attention: bool = False,
        max_num_positional_embeddings: int = 512,
        compute_dtype=torch.float32,
        final_dropout: bool = True,
        positional_embeddings: str | None = "sinusoidal",
        interleave_self_attention=False,
    ):
        """Create ``num_layers`` identical blocks; all arguments are recorded in ``self.config``."""
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=final_dropout,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # Report through logging instead of print() so library users control the output.
        import logging

        logging.getLogger(__name__).info(
            "Total number of SelfAttentionTransformer parameters: %d",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        return_all_hidden_states: bool = False,
    ):
        """Apply every block in order.

        Returns:
            The final hidden states, or ``(hidden_states, all_hidden_states)``
            when ``return_all_hidden_states`` is true; the list holds the input
            followed by the output of every block.
        """
        hidden_states = hidden_states.contiguous()
        all_hidden_states = [hidden_states]

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)
            all_hidden_states.append(hidden_states)

        if return_all_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states
def swish(x):
    """Swish activation function."""
    return x * torch.sigmoid(x)


class SinusoidalPositionalEncoding(nn.Module):
    """
    Produces a sinusoidal encoding of shape (B, T, w)
    given timesteps of shape (B, T).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timesteps):
        """Encode ``timesteps`` as ``[sin, cos]`` features.

        Raises:
            ValueError: If ``timesteps`` is not 2-D.
        """
        if timesteps.dim() != 2:
            raise ValueError(f"Expected `timesteps` of shape (B, T), got {tuple(timesteps.shape)}")
        timesteps = timesteps.float()  # ensure float

        half_dim = self.embedding_dim // 2
        # typical log space frequencies for sinusoidal encoding
        exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
            math.log(10000.0) / half_dim
        )
        # Expand timesteps to (B, T, 1) then multiply
        freqs = timesteps.unsqueeze(-1) * exponent.exp()  # (B, T, half_dim)

        return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)  # (B, T, 2 * half_dim)


class CategorySpecificLinear(nn.Module):
    """Linear layer with category-specific weights and biases for multi-embodiment support."""

    def __init__(self, num_categories, input_dim, hidden_dim):
        super().__init__()
        self.num_categories = num_categories
        # For each category, we have separate weights and biases.
        self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x, cat_ids):
        """
        Args:
            x: [B, T, input_dim] input tensor
            cat_ids: [B] category/embodiment IDs
        Returns:
            [B, T, hidden_dim] output tensor
        """
        selected_W = self.W[cat_ids]
        selected_b = self.b[cat_ids]
        return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)

    @staticmethod
    def _tile_to_size(tensor, dim, new_size):
        """Repeat ``tensor`` along ``dim`` (whole copies plus a leading slice) until it has ``new_size`` entries."""
        repeat_times, remainder = divmod(new_size, tensor.shape[dim])
        parts = [tensor] * repeat_times
        if remainder > 0:
            parts.append(tensor.narrow(dim, 0, remainder))
        return torch.cat(parts, dim=dim)

    def expand_action_dimension(self, old_action_dim, new_action_dim, expand_input=False, expand_output=False):
        """
        Safely expand action dimension with explicit targeting.

        New entries are filled by cyclically copying the existing weights.

        Args:
            old_action_dim: Original action dimension
            new_action_dim: New (larger) action dimension
            expand_input: Whether to expand input dimension (dim=1)
            expand_output: Whether to expand output dimension (dim=2)

        Raises:
            ValueError: If ``new_action_dim`` is not larger than ``old_action_dim``.
        """
        if new_action_dim <= old_action_dim:
            raise ValueError(f"New action dim {new_action_dim} must be larger than old action dim {old_action_dim}")

        with torch.no_grad():
            # Expand input dimension (dim=1) only if explicitly requested AND dimensions match
            if expand_input and self.W.shape[1] == old_action_dim:
                self.W = nn.Parameter(self._tile_to_size(self.W, 1, new_action_dim))

            # Expand output dimension (dim=2) only if explicitly requested AND dimensions match
            if expand_output and self.W.shape[2] == old_action_dim:
                self.W = nn.Parameter(self._tile_to_size(self.W, 2, new_action_dim))

                # The bias follows the output dimension, so it is only touched
                # when the output dimension was actually expanded. This keeps
                # W and b consistent even when hidden_dim == old_action_dim and
                # only the input side is expanded.
                if self.b.shape[1] == old_action_dim:
                    self.b = nn.Parameter(self._tile_to_size(self.b, 1, new_action_dim))


class SmallMLP(nn.Module):
    """Two-layer MLP with a ReLU in between."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = F.relu(self.layer1(x))
        return self.layer2(hidden)


class CategorySpecificMLP(nn.Module):
    """Two-layer MLP with category-specific weights for multi-embodiment support."""

    def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.num_categories = num_categories
        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)

    def forward(self, x, cat_ids):
        """
        Args:
            x: [B, T, input_dim] input tensor
            cat_ids: [B] category/embodiment IDs
        Returns:
            [B, T, output_dim] output tensor
        """
        hidden = F.relu(self.layer1(x, cat_ids))
        return self.layer2(hidden, cat_ids)

    def expand_action_dimension(self, old_action_dim, new_action_dim):
        """
        Expand action dimension by copying weights from existing dimensions.

        Args:
            old_action_dim: Original action dimension
            new_action_dim: New (larger) action dimension
        """
        # self.layer1 does not take action_dim as input, so no expansion needed
        self.layer2.expand_action_dimension(old_action_dim, new_action_dim, expand_input=False, expand_output=True)


class MultiEmbodimentActionEncoder(nn.Module):
    """Action encoder with multi-embodiment support and sinusoidal positional encoding."""

    def __init__(self, action_dim, hidden_size, num_embodiments):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_embodiments = num_embodiments

        # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)  # (d -> w)
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)  # (2w -> w)
        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)  # (w -> w)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps, cat_ids):
        """
        Args:
            actions: [B, T, action_dim] action tensor
            timesteps: [B,] timesteps - a single scalar per batch item
            cat_ids: [B,] category/embodiment IDs
        Returns:
            [B, T, hidden_size] encoded action features
        Raises:
            ValueError: If ``timesteps`` does not have shape (B,).
        """
        B, T, _ = actions.shape

        # 1) Expand each batch's single scalar time 'tau' across all T steps => (B, T)
        if timesteps.dim() != 1 or timesteps.shape[0] != B:
            raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
        timesteps = timesteps.unsqueeze(1).expand(-1, T)

        # 2) Standard action MLP step for shape => (B, T, w)
        a_emb = self.W1(actions, cat_ids)

        # 3) Get the sinusoidal encoding (B, T, w)
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.W2(x, cat_ids))

        # 5) Finally W3 => (B, T, w)
        x = self.W3(x, cat_ids)
        return x

    def expand_action_dimension(self, old_action_dim, new_action_dim):
        """
        Expand action dimension by copying weights from existing dimensions.

        Args:
            old_action_dim: Original action dimension
            new_action_dim: New (larger) action dimension
        """
        # Only W1 takes action_dim as input, so only expand its input dimension
        self.W1.expand_action_dimension(old_action_dim, new_action_dim, expand_input=True, expand_output=False)
def swish(x):
    """Swish activation: the input multiplied by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate


class SinusoidalPositionalEncoding(nn.Module):
    """
    Maps timesteps of shape (B, T) to a sinusoidal encoding of shape (B, T, w),
    laid out as all sine features followed by all cosine features.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timesteps):
        # Work in float regardless of the incoming dtype.
        timesteps = timesteps.float()

        # Unpacking doubles as a check that the input is exactly (B, T).
        _batch, _steps = timesteps.shape

        half_dim = self.embedding_dim // 2
        # Log-spaced frequencies, one per sin/cos pair.
        positions = torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
        exponent = -positions * (math.log(10000.0) / half_dim)
        # (B, T, 1) * (half_dim,) -> (B, T, half_dim)
        angles = timesteps.unsqueeze(-1) * exponent.exp()

        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (B, T, w)


class SmallMLP(nn.Module):
    """Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.layer2(F.relu(self.layer1(x)))


class ActionEncoder(nn.Module):
    """Encode an action chunk together with its (per-sample) flow-matching timestep."""

    def __init__(self, action_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
        self.W1 = nn.Linear(action_dim, hidden_size)  # (d -> w)
        self.W2 = nn.Linear(2 * hidden_size, hidden_size)  # (2w -> w)
        self.W3 = nn.Linear(hidden_size, hidden_size)  # (w -> w)

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps):
        """
        actions:   shape (B, T, action_dim)
        timesteps: shape (B,)  -- a single scalar per batch item
        returns:   shape (B, T, hidden_size)
        """
        batch_size, horizon, _ = actions.shape

        # Guard first: only a (B,) vector can be replicated across the T steps.
        if not (timesteps.dim() == 1 and timesteps.shape[0] == batch_size):
            raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
        per_step_tau = timesteps.unsqueeze(1).expand(-1, horizon)  # (B,) => (B, T)

        # Project actions to the hidden width => (B, T, w)
        action_emb = self.W1(actions)

        # Sinusoidal timestep features, cast to match the action embedding => (B, T, w)
        tau_emb = self.pos_encoding(per_step_tau).to(dtype=action_emb.dtype)

        # Fuse along the feature axis => (B, T, 2w), mix with W2 + swish, then W3 => (B, T, w)
        fused = torch.cat([action_emb, tau_emb], dim=-1)
        return self.W3(swish(self.W2(fused)))
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import random +import re +import warnings +from copy import deepcopy +from pathlib import Path +from typing import Any + +import albumentations as A +import numpy as np +import torch +import torchvision.transforms.v2 as transforms +from PIL import Image +from transformers import AutoProcessor +from transformers.feature_extraction_utils import BatchFeature +from transformers.utils import cached_file +from vllm.logger import init_logger + +from vllm_omni.diffusion.models.gr00t.configs.embodiment.embodiment_configs import ModalityConfig +from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag +from vllm_omni.diffusion.models.gr00t.dataio.interfaces import BaseProcessor +from vllm_omni.diffusion.models.gr00t.dataio.state_action.state_action_processor import StateActionProcessor +from vllm_omni.diffusion.models.gr00t.dataio.utils import parse_modality_configs, to_json_serializable + +from .image_augmentations import ( + apply_with_replay, + build_image_transformations, + build_image_transformations_albumentations, +) + +try: + from transformers import Qwen3VLProcessor +except ImportError: + Qwen3VLProcessor = None + +logger = init_logger(__name__) + +# Suppress protobuf deprecation warnings +warnings.filterwarnings("ignore", category=DeprecationWarning, module="google.protobuf") + +### Mapping from embodiment tag to projector index. 
EMBODIMENT_TAG_TO_PROJECTOR_INDEX = {
    ##### Pretrain embodiment ids (in base model) #####
    "oxe_droid_relative_eef_relative_joint": 24,
    "xdof_relative_eef_relative_joint": 27,
    "xdof_relative_eef_relative_joint_subtask": 27,
    "real_g1_relative_eef_relative_joints": 25,
    "real_r1_pro_sharpa_relative_eef": 26,
    "real_r1_pro_sharpa_relative_eef_human": 26,
    "real_r1_pro_sharpa_relative_eef_maxinsights": 26,
    "real_r1_pro_sharpa_relative_eef_mecka": 26,
    ##### Posttrain embodiment ids #####
    "unitree_g1_full_body_with_waist_height_nav_cmd": 25,
    "unitree_g1_sonic": 11,
    "simpler_env_google": 0,
    "simpler_env_widowx": 1,
    "libero_sim": 2,
    "new_embodiment": 10,
}

QWEN3_VL_2B_PROCESSOR = "Qwen/Qwen3-VL-2B-Instruct"


def build_processor(model_name: str, transformers_loading_kwargs: dict) -> "Qwen3VLProcessor":
    """Load the Qwen3-VL processor used to tokenize text and preprocess images.

    Args:
        model_name: Hugging Face repo id or local path of the backbone.
        transformers_loading_kwargs: Extra keyword arguments forwarded to
            ``Qwen3VLProcessor.from_pretrained``.

    Raises:
        ImportError: If the installed transformers has no ``Qwen3VLProcessor``.
    """
    if Qwen3VLProcessor is None:
        # Qwen3-VL support first shipped in transformers 4.57.0.
        raise ImportError(
            "Qwen3VLProcessor is not available. Please upgrade transformers: pip install transformers>=4.57.0"
        )
    if model_name == "nvidia/Cosmos-Reason2-2B":
        # Cosmos-Reason2-2B is a Qwen3-VL 2B backbone but the NVIDIA repo does
        # not publish a Qwen3VLProcessor-compatible preprocessor. Fall back to
        # the upstream Qwen3-VL repo for processor artifacts only; model
        # weights are still loaded from `nvidia/Cosmos-Reason2-2B`.
        logger.warning_once(
            "Substituting processor from %s because %s does not ship one. "
            "If you fine-tune Cosmos-Reason2-2B's tokenizer/image processor, "
            "load the processor explicitly instead of relying on this fallback.",
            QWEN3_VL_2B_PROCESSOR,
            model_name,
        )
        model_name = QWEN3_VL_2B_PROCESSOR
    return Qwen3VLProcessor.from_pretrained(model_name, **transformers_loading_kwargs)


class Gr00tN1d7DataCollator:
    """Batch per-sample feature dicts: VLM content goes through the processor, arrays are stacked."""

    def __init__(
        self,
        model_name: str,
        model_type: str = "qwen",
        transformers_loading_kwargs: dict | None = None,
    ):
        """
        Args:
            model_name: Backbone name passed to ``build_processor``.
            model_type: Stored for reporting in ``__str__``.
            transformers_loading_kwargs: Extra keyword arguments for the
                processor loader. ``None`` (the default) means no extra
                arguments; a fresh dict is used per instance instead of a
                shared mutable default.
        """
        if transformers_loading_kwargs is None:
            transformers_loading_kwargs = {}
        ### We need to use the same processor for padding input ids and concat
        self.processor = build_processor(model_name, transformers_loading_kwargs)
        # Set padding side to 'left' for Flash Attention compatibility
        self.processor.tokenizer.padding_side = "left"
        self.model_type = model_type
        self.model_name = model_name

    def __call__(self, features: list[dict[str, Any]]) -> BatchFeature:
        """Collate ``features`` into ``BatchFeature(data={"inputs": batch})``.

        Raises:
            NotImplementedError: If a sample carries already-tokenized VLM
                tensors (``pixel_values``, ``image_grid_thw``,
                ``attention_mask`` or ``input_ids``).
        """
        batch = {}
        keys = list(set().union(*(elem.keys() for elem in features)))

        for key in keys:
            values = [elem[key] for elem in features if key in elem]
            if key == "vlm_content":
                # Handle vlm_content specially - extract text and images
                text_list = []
                image_inputs = []
                for content in values:
                    text_list.append(content["text"])
                    image_inputs += content["images"]

                vlm_inputs = self.processor(
                    text=text_list,
                    images=image_inputs,
                    return_tensors="pt",
                    padding=True,
                )
                batch.update(vlm_inputs.items())
            elif key in (
                "pixel_values",
                "image_grid_thw",
                "attention_mask",
                "input_ids",
            ):
                # NotImplementedError is still an Exception, so existing handlers keep working.
                raise NotImplementedError(f"Not implemented: collating pre-tokenized key {key!r}")
            else:
                # state, state_mask, action and action_mask - stack to form batch dimension
                batch[key] = torch.from_numpy(np.stack(values))
        return BatchFeature(data={"inputs": batch})

    def __str__(self):
        return f"Gr00tN1d7DataCollator(model_name={self.model_name}, model_type={self.model_type})"
self, + modality_configs: dict[str, dict[str, ModalityConfig]], + statistics: (dict[str, dict[str, dict[str, dict[str, list[float]]]]] | None) = None, + use_percentiles: bool = False, + clip_outliers: bool = True, + image_crop_size: list[int] = None, + image_target_size: list[int] = None, + shortest_image_edge: int = 256, + crop_fraction: float = 0.95, + random_rotation_angle: int | None = None, + color_jitter_params: dict[str, float] | None = None, + formalize_language: bool = True, + model_name: str = "nvidia/Cosmos-Reason2-2B", + model_type: str = "qwen", + max_state_dim: int = 29, + max_action_dim: int = 29, + max_action_horizon: int = 50, + apply_sincos_state_encoding: bool = False, + use_albumentations: bool = False, + extra_augmentation_config: dict | None = None, + use_relative_action: bool = False, + embodiment_id_mapping: dict[str, int] | None = None, + transformers_loading_kwargs: dict = {"trust_remote_code": True}, + # State augmentation + exclude_state: bool = False, + state_dropout_prob: float = 0.0, + # Normalization + use_mean_std: bool = False, + # Backward-compat params (stored but not actively used) + letter_box_transform: bool = False, + ): + self.modality_configs = parse_modality_configs(modality_configs) + + # Initialize StateActionProcessor for state/action normalization + self.state_action_processor = StateActionProcessor( + modality_configs=modality_configs, + statistics=statistics, + use_percentiles=use_percentiles, + clip_outliers=clip_outliers, + apply_sincos_state_encoding=apply_sincos_state_encoding, + use_relative_action=use_relative_action, + ) + + # Save state action processor settings + self.use_percentiles = use_percentiles + self.use_mean_std = use_mean_std + self.clip_outliers = clip_outliers + self.apply_sincos_state_encoding = apply_sincos_state_encoding + self.use_relative_action = use_relative_action + self.extra_augmentation_config = extra_augmentation_config + + # State augmentation settings + self.exclude_state = 
exclude_state + self.state_dropout_prob = state_dropout_prob + + self.letter_box_transform = letter_box_transform + + # Save VLM settings + self.formalize_language = formalize_language + self.model_name = model_name + self.model_type = model_type + + self.max_state_dim = max_state_dim + self.max_action_dim = max_action_dim + self.max_action_horizon = max_action_horizon + + # Save image processing settings + self.image_crop_size = image_crop_size + self.image_target_size = image_target_size + self.random_rotation_angle = random_rotation_angle + self.color_jitter_params = color_jitter_params + self.processor = build_processor(model_name, transformers_loading_kwargs) + # Set padding side to 'left' for Flash Attention compatibility + self.processor.tokenizer.padding_side = "left" + self.embodiment_id_mapping = embodiment_id_mapping or EMBODIMENT_TAG_TO_PROJECTOR_INDEX + # Merge any missing pre-trained embodiment tags into the custom mapping + for k, v in EMBODIMENT_TAG_TO_PROJECTOR_INDEX.items(): + if k not in self.embodiment_id_mapping: + self.embodiment_id_mapping[k] = v + self.shortest_image_edge = shortest_image_edge + self.crop_fraction = crop_fraction + + # Statistics cache (mirrors state_action_processor.statistics for serialization) + self.statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]] = {} + + # Choose between torchvision and albumentations transforms + self.use_albumentations = use_albumentations + if use_albumentations: + self.train_image_transform, self.eval_image_transform = build_image_transformations_albumentations( + image_target_size, + image_crop_size, + random_rotation_angle, + color_jitter_params, + shortest_image_edge, + crop_fraction, + extra_augmentation_config=self.extra_augmentation_config, + ) + else: + self.train_image_transform, self.eval_image_transform = build_image_transformations( + image_target_size, + image_crop_size, + random_rotation_angle, + color_jitter_params, + ) + self._collator = self.data_collator_class( + 
def set_statistics(
    self,
    statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]],
    override: bool = False,
) -> None:
    """Register per-embodiment normalization statistics.

    Args:
        statistics: Mapping from embodiment tag to its statistics tree.
        override: When true, tags already present in the local cache are replaced;
            otherwise they are kept and a notice is printed.
    """
    for tag, tag_statistics in statistics.items():
        if not override and tag in self.statistics:
            print(f"Embodiment tag {tag} already in statistics, skipping updating")
            continue
        if override:
            print(f"Overriding statistics for {tag}")
        # Deep copy so the local cache never aliases the caller's data.
        self.statistics[tag] = deepcopy(tag_statistics)

    sa_processor = self.state_action_processor
    sa_processor.set_statistics(statistics, override=override)

    # Recompute the per-embodiment action dimensions from the processor's view.
    self.action_dim = {}
    self.action_dim.update((tag, sa_processor.get_action_dim(tag)) for tag in sa_processor.statistics)
def process_observation(self, observation: dict[str, Any], embodiment_tag: EmbodimentTag):
    """Process a batched observation for inference.

    Args:
        observation: Flat dict. For every state key and camera view in the
            embodiment's modality config it must hold ``"state.<key>"`` and
            ``"video.<view>"``; the first language modality key is looked up
            verbatim and must hold one instruction string per batch element.
            Video values are numpy arrays of shape (B, T, H, W, C). State arrays
            are concatenated and padded along their last axis.
        embodiment_tag: Embodiment tag identifying the robot configuration.

    Returns:
        BatchFeature with the tokenized VLM inputs plus ``state``,
        ``embodiment_id`` and ``action_mask``.

    Raises:
        AssertionError: If the concatenated state width exceeds ``max_state_dim``,
            the stacked video is not 6-D, or the configured action horizon
            exceeds ``max_action_horizon``.
    """
    modality_config = self.modality_configs[embodiment_tag.value]
    transformed_observation = {}

    # Normalize states
    state_keys = modality_config["state"].modality_keys
    state_data = {key: observation[f"state.{key}"] for key in state_keys}
    exclude_state = self.exclude_state or getattr(modality_config["state"], "exclude_state", False)
    if exclude_state:
        # State is disabled: feed zeros with the same layout as the real state.
        normalized_states = torch.cat(
            [torch.from_numpy(np.zeros_like(state_data[key])) for key in state_keys], dim=-1
        )
    else:
        norm_state_dict = self.state_action_processor.apply_state(
            state=state_data, embodiment_tag=embodiment_tag.value
        )
        normalized_states = torch.cat([torch.from_numpy(norm_state_dict[key]) for key in state_keys], dim=-1)

    # The feature width is the LAST axis (the one concatenated and padded).
    # Checking axis 1 would test the time axis for (B, T, D) input.
    state_dim = normalized_states.shape[-1]
    assert state_dim <= self.max_state_dim, (
        f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}"
    )
    padding_shape = (
        *normalized_states.shape[:-1],
        self.max_state_dim - state_dim,
    )
    normalized_states = torch.cat([normalized_states, torch.zeros(padding_shape)], dim=-1)
    transformed_observation["state"] = normalized_states

    # Process images: observation values are (B, T, H, W, C) numpy arrays
    image_keys = modality_config["video"].modality_keys
    images_dict = {view: torch.from_numpy(observation[f"video.{view}"]) for view in image_keys}
    images = torch.stack([images_dict[view] for view in image_keys], dim=2)  # (B, T, V, H, W, C)
    assert images.ndim == 6
    B, T, V, img_H, img_W, img_C = images.shape

    if self.use_albumentations:
        # Albumentations path works on a flat list of PIL images.
        images_flat = images.reshape(B * T * V, img_H, img_W, img_C)
        pil_images = [Image.fromarray(img.numpy()) for img in images_flat]
        transformed_pil, _ = apply_with_replay(self.eval_image_transform, pil_images)
        transformed_stacked = torch.stack(transformed_pil)  # (B*T*V, C, H_new, W_new)
        _, img_C_new, img_H_new, img_W_new = transformed_stacked.shape
        transformed_images = transformed_stacked.reshape(B, T * V, img_C_new, img_H_new, img_W_new).numpy()
    else:
        # Rearrange (B, T, V, H, W, C) to (B, T*V, C, H, W) for torchvision.
        images_perm = images.permute(0, 1, 2, 5, 3, 4).reshape(B, T * V, img_C, img_H, img_W)
        transformed_images = self.eval_image_transform(images_perm).numpy()

    # Only the first language key is used at inference.
    language_key = modality_config["language"].modality_keys[0]
    language = [
        re.sub(r"[^\w\s]", "", lang.lower()) if self.formalize_language else lang
        for lang in observation[language_key]
    ]

    # Build the chat-template text per sample, then tokenize the whole batch at
    # once so padding is consistent across samples.
    texts, all_images = [], []
    for i in range(B):
        vlm_inputs = self._apply_vlm_processing(transformed_images[i], language[i])
        vc = vlm_inputs["vlm_content"]
        texts.append(vc["text"])
        all_images.extend(vc["images"])
    tokenized = self.processor(text=texts, images=all_images, return_tensors="pt", padding=True)
    for k, v in tokenized.items():
        transformed_observation[k] = v

    embodiment_id = torch.ones(B, dtype=torch.int32) * self.embodiment_id_mapping[embodiment_tag.value]
    transformed_observation["embodiment_id"] = embodiment_id

    # Action mask: shape (B, max_action_horizon), 1 in the valid horizon window
    action_config = modality_config["action"]
    action_horizon = len(action_config.delta_indices)
    assert action_horizon <= self.max_action_horizon, (
        f"Action horizon {action_horizon} (from delta_indices) exceeds"
        f" max_action_horizon {self.max_action_horizon}. Increase model config"
        f" action_horizon to >= {action_horizon}."
    )
    action_mask = torch.zeros((B, self.max_action_horizon), dtype=torch.float32)
    if action_horizon > 0:
        action_mask[:, :action_horizon] = 1.0
    transformed_observation["action_mask"] = action_mask

    return BatchFeature(transformed_observation)
normalized_actions.shape[0], + self.max_action_dim - normalized_actions.shape[1], + ), + ], + dim=-1, + ) # (t, max_action_dim) + # Pad action to max_action_horizon + action_horizon = normalized_actions.shape[0] + assert action_horizon <= self.max_action_horizon, ( + f"Action sequence length {action_horizon} exceeds max_action_horizon" + f" {self.max_action_horizon}. Increase model config action_horizon to" + f" >= {action_horizon}." + ) + normalized_actions = torch.cat( + [ + normalized_actions, + torch.zeros( + self.max_action_horizon - normalized_actions.shape[0], + self.max_action_dim, + ), + ], + dim=0, + ) # (max_action_horizon, max_action_dim) + # Create action mask + action_mask = torch.ones_like(normalized_actions) + action_mask[action_horizon:] = 0 + action_mask[:, action_dim:] = 0 + else: + assert not self.training, "Action is required in training mode" + normalized_actions = None + action_mask = None + + # Concatenate states with optional dropout/noise augmentation + state_keys = self.modality_configs[embodiment_tag.value]["state"].modality_keys + exclude_state = self.exclude_state or getattr( + self.modality_configs[embodiment_tag.value]["state"], "exclude_state", False + ) + if exclude_state or ( + self.state_dropout_prob > 0 and random.random() < self.state_dropout_prob and self.training + ): + normalized_states = torch.cat( + [torch.from_numpy(np.zeros_like(state_data[key])) for key in state_keys], dim=-1 + ) + else: + normalized_states = torch.cat([torch.from_numpy(norm_state_dict[key]) for key in state_keys], dim=-1) + normalized_states = torch.cat( + [ + normalized_states, + torch.zeros( + normalized_states.shape[0], + self.max_state_dim - normalized_states.shape[1], + ), + ], + dim=-1, + ) + + # Crop and resize images. 
def _get_vlm_inputs(
    self,
    image_keys: list[str],
    images: dict[str, list[Image.Image]],
    masks: dict[str, list[np.ndarray]] | None,
    image_transform: transforms.Compose | A.Compose,
    language: str,
):
    """Transform every camera view and package frames + text for the VLM.

    Args:
        image_keys: Camera view names, in the order they are interleaved.
        images: Mapping from view name to that view's frames (indexed by view
            below, hence a mapping rather than a flat list).
        masks: Optional per-view masks; only supported on the albumentations path.
        image_transform: Transform applied to the frames.
        language: Instruction text placed after the images.

    Returns:
        The ``{"vlm_content": ...}`` dict produced by ``_apply_vlm_processing``.

    Raises:
        ValueError: If ``masks`` is given while albumentations is disabled.
        AssertionError: If a view is missing or a transformed view is not a
            4-D uint8 tensor with 3 channels.
    """
    temporal_stacked_images = {}

    if self.use_albumentations:
        # Use albumentations transforms
        replay = None
        for view in image_keys:
            assert view in images, f"{view} not in {images}"
            if masks is not None:
                assert view in masks, f"{view} not in masks"
            view_masks = masks.get(view) if masks else None
            view_images = images[view]

            # The replay from the first view is handed to later views so the
            # transform is applied consistently across views.
            transformed_images, replay = apply_with_replay(image_transform, view_images, view_masks, replay)
            temporal_stacked_images[view] = torch.stack(transformed_images)  # (T, C, H, W)
    else:
        if masks is not None:
            # The option is named ``use_albumentations``.
            raise ValueError("Mask transforms require albumentations. Set use_albumentations=True.")
        # Use torchvision transforms
        for view in image_keys:
            assert view in images, f"{view} not in {images}"
            temporal_stacked_images[view] = torch.stack(
                [image_transform(img) for img in images[view]]
            )  # (T, C, H, W)

    for k, v in temporal_stacked_images.items():
        assert isinstance(k, str), f"{k} is not a string"
        assert isinstance(v, torch.Tensor), f"{v} is not a torch tensor"
        assert v.ndim == 4, f"{v} is not a 4D tensor"
        assert v.dtype == torch.uint8, f"{v} is not a uint8 tensor"
        assert v.shape[1] == 3, f"{v} is not a 3 channel tensor"

    stacked_images = (
        torch.stack([temporal_stacked_images[view] for view in image_keys], dim=1).flatten(0, 1).numpy()
    )  # (T*V, C, H, W), processor expects numpy array

    vlm_inputs = self._apply_vlm_processing(stacked_images, language)
    return vlm_inputs
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str | Path, **kwargs):
    """Rebuild a processor from the files written by ``save_pretrained``.

    Reads ``processor_config.json`` and ``statistics.json`` (both required) and
    ``embodiment_id.json`` (used when the resolved path exists).

    Args:
        pretrained_model_name_or_path: A local directory, or an identifier that
            is resolved through ``cached_file`` when it is not a directory.
        **kwargs: Optional overrides. ``transformers_loading_kwargs`` is
            forwarded to the constructor; ``modality_configs`` entries replace or
            add per-embodiment configs; the names in ``override_keys`` below
            replace the saved value unless the override is ``None``. Any other
            keyword is not consulted.

    Returns:
        A new instance of ``cls``.
    """
    transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", {"trust_remote_code": True})
    pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
    config_file = pretrained_model_name_or_path / "processor_config.json"
    statistics_file = pretrained_model_name_or_path / "statistics.json"
    embodiment_id_file = pretrained_model_name_or_path / "embodiment_id.json"
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if not is_local:
        # NOTE(review): cached_file is defined outside this block; presumably it
        # resolves/downloads a file for a non-local identifier - confirm. If it
        # raises when embodiment_id.json is absent, the exists() fallback below
        # only ever applies to local directories.
        config_file = Path(cached_file(pretrained_model_name_or_path, "processor_config.json"))
        statistics_file = Path(cached_file(pretrained_model_name_or_path, "statistics.json"))
        embodiment_id_file = Path(cached_file(pretrained_model_name_or_path, "embodiment_id.json"))

    with open(config_file) as f:
        config = json.load(f)
    with open(statistics_file) as f:
        statistics = json.load(f)
    if embodiment_id_file.exists():
        with open(embodiment_id_file) as f:
            embodiment_id_mapping = json.load(f)
    else:
        # No saved mapping: pass None on to the constructor.
        embodiment_id_mapping = None
    processor_kwargs = config["processor_kwargs"]
    processor_kwargs["statistics"] = statistics
    processor_kwargs["embodiment_id_mapping"] = embodiment_id_mapping

    # Backfill fields that older checkpoints may not have serialized.
    # Without these, __init__ defaults silently apply - correct today but
    # fragile if defaults ever change.
    processor_kwargs.setdefault("model_name", "nvidia/Cosmos-Reason2-2B")
    processor_kwargs.setdefault("model_type", "qwen")
    processor_kwargs.setdefault("clip_outliers", True)

    # Directly override other processor kwargs
    if kwargs:
        # Override modality configs while keeping pretrained embodiment configs
        modality_configs = kwargs.pop("modality_configs", {})
        for embodiment_tag, modality_config in modality_configs.items():
            processor_kwargs["modality_configs"][embodiment_tag] = modality_config
        # Only these saved settings may be replaced by the caller.
        override_keys = [
            "random_rotation_angle",
            "color_jitter_params",
            "use_relative_action",
            "exclude_state",
            "state_dropout_prob",
            "use_mean_std",
            "model_name",
            "model_type",
            "max_action_horizon",
            "max_state_dim",
            "max_action_dim",
        ]
        for key in override_keys:
            if key in kwargs:
                override = kwargs.pop(key)
                # An explicit None means "keep the saved value".
                if override is not None:
                    processor_kwargs[key] = override
    return cls(**processor_kwargs, transformers_loading_kwargs=transformers_loading_kwargs)
OmniDiffusionConfig +from vllm_omni.diffusion.models.gr00t.policy import Gr00tPolicy +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = init_logger(__name__) + + +def get_gr00t_n1d7_post_process_func(od_config: OmniDiffusionConfig): + del od_config + + def post_process_func(x): + return x + + return post_process_func + + +def _to_float32_action_dict(actions: Mapping[str, Any]) -> dict[str, np.ndarray]: + converted = {str(key): np.asarray(value, dtype=np.float32) for key, value in actions.items()} + if not converted: + raise RuntimeError("GR00T policy returned an empty action dict.") + return converted + + +def _default_device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" + + +class Gr00tN1d7Pipeline(nn.Module): + """GR00T N1.7 policy pipeline backed by vLLM-Omni's local GR00T port. + + vLLM-Omni owns the serving integration: OpenPI observations arrive through + `sampling_params.extra_args["robot_obs"]`, this pipeline runs GR00T policy + inference, and actions are returned through `DiffusionOutput.multimodal_output`. 
+ """ + + EXTRA_BODY_PARAMS = frozenset({"robot_obs"}) + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = "") -> None: + super().__init__() + self.od_config = od_config + self.prefix = prefix + self.model_path = od_config.model + self.model_config = dict(od_config.model_config or {}) + custom_args = od_config.custom_pipeline_args or {} + + default_embodiment = "OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT" + self.embodiment_tag = str( + custom_args.get("embodiment_tag") or self.model_config.get("embodiment_tag") or default_embodiment + ) + self.strict = bool(custom_args.get("strict", self.model_config.get("strict", True))) + self.device = str(custom_args.get("device") or self.model_config.get("device") or _default_device()) + + logger.info("Loading GR00T N1.7 policy from %s with embodiment_tag=%s", self.model_path, self.embodiment_tag) + self.policy = Gr00tPolicy( + model_path=self.model_path, + embodiment_tag=self.embodiment_tag, + device=self.device, + strict=self.strict, + ) + + def reset(self) -> dict[str, Any]: + reset = getattr(self.policy, "reset", None) + if callable(reset): + info = reset() + return info or {} + return {} + + @property + def weights_sources(self) -> tuple[Any, ...]: + return () + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + for _ in weights: + pass + return set() + + def _dummy_actions(self) -> dict[str, np.ndarray]: + embodiment_value = self.policy.embodiment_tag.value + action_config = self.policy.modality_configs["action"] + horizon = len(action_config.delta_indices) + norm_params = self.policy.processor.state_action_processor.norm_params[embodiment_value]["action"] + actions = {} + for key in action_config.modality_keys: + dim = norm_params[key]["dim"] + dim = int(dim.item() if hasattr(dim, "item") else dim) + actions[key] = np.zeros((1, horizon, dim), dtype=np.float32) + return actions + + @torch.inference_mode() + def forward(self, req: OmniDiffusionRequest, **kwargs) -> 
DiffusionOutput: + del kwargs + extra_args = getattr(req.sampling_params, "extra_args", {}) or {} + robot_obs = extra_args.get("robot_obs") + if robot_obs is None: + if getattr(req, "request_ids", None) == ["dummy_req_id"]: + return DiffusionOutput(multimodal_output={"actions": self._dummy_actions()}) + return DiffusionOutput(error="Gr00tN1d7Pipeline.forward expects sampling_params.extra_args['robot_obs'].") + if not isinstance(robot_obs, Mapping): + return DiffusionOutput(error=f"robot_obs must be a dict, got {type(robot_obs).__name__}.") + + if extra_args.get("reset"): + self.reset() + + policy_obs = _normalize_observation(robot_obs, language_key=self.policy.language_key) + result = self.policy.get_action(policy_obs) + actions = result[0] if isinstance(result, tuple) else result + if not isinstance(actions, Mapping): + return DiffusionOutput(error=f"GR00T policy returned {type(actions).__name__}; expected dict actions.") + return DiffusionOutput(multimodal_output={"actions": _to_float32_action_dict(actions)}) + + +def _normalize_observation(robot_obs: Mapping[str, Any], *, language_key: str) -> dict[str, Any]: + obs: dict[str, Any] = {} + if "video" in robot_obs: + obs["video"] = robot_obs["video"] + elif "images" in robot_obs: + obs["video"] = robot_obs["images"] + if "state" in robot_obs: + obs["state"] = robot_obs["state"] + if "language" in robot_obs: + obs["language"] = robot_obs["language"] + else: + prompt = robot_obs.get("prompt") + if prompt is not None: + obs["language"] = {language_key: [[str(prompt)]]} + return obs diff --git a/vllm_omni/diffusion/models/gr00t/policy.py b/vllm_omni/diffusion/models/gr00t/policy.py new file mode 100644 index 00000000000..4e81e7ce7b6 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/policy.py @@ -0,0 +1,717 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gr00t Policy implementation for inference. + +This module provides the core policy classes for running Gr00t models: +- Gr00tPolicy: Base policy class for model inference +- Gr00tSimPolicyWrapper: Wrapper for compatibility with existing Gr00t simulation environments +""" + +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from transformers import AutoModel, AutoProcessor + +from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import FINETUNE_ONLY_TAGS, POSTTRAIN_TAGS, EmbodimentTag +from vllm_omni.diffusion.models.gr00t.dataio.interfaces import BaseProcessor +from vllm_omni.diffusion.models.gr00t.dataio.types import MessageType, ModalityConfig, VLAStepData +from vllm_omni.diffusion.models.gr00t.policy_base import BasePolicy, PolicyWrapper + + +def _rec_to_dtype(value: Any, dtype: torch.dtype) -> Any: + """Recursively convert floating-point tensors in nested collator output.""" + if isinstance(value, torch.Tensor): + return value.to(dtype=dtype) if torch.is_floating_point(value) else value + if isinstance(value, dict) or hasattr(value, "items"): + return {key: _rec_to_dtype(item, dtype) for key, item in value.items()} + if isinstance(value, list): + return [_rec_to_dtype(item, dtype) for item in value] + return value + + +class Gr00tPolicy(BasePolicy): + """Core policy class for Gr00t model inference. 
+ + This policy handles the end-to-end inference pipeline: + 1. Validates input observations + 2. Processes observations with pretrained VLA processor + 3. Runs model inference + 4. Decodes and returns actions + + The policy expects observations with specific modalities (video, state, language) + and returns actions in the format defined by the model's modality configuration. + """ + + def __init__( + self, + embodiment_tag: EmbodimentTag | str, + model_path: str, + *, + device: int | str, + strict: bool = True, + ): + """Initialize the Gr00t Policy. + + Args: + embodiment_tag: The embodiment tag defining the robot/environment type. + Accepts an EmbodimentTag enum or a string (resolved case-insensitively). + model_path: Path to the pretrained model checkpoint directory + device: Device to run the model on (e.g., 'cuda:0', 0, 'cpu') + strict: Whether to enforce strict input validation (default: True) + """ + # Import these local modules to register GR00T with Hugging Face Auto classes. + from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7 # noqa: F401 + from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor # noqa: F401 + + super().__init__(strict=strict) + if isinstance(embodiment_tag, str): + embodiment_tag = EmbodimentTag.resolve(embodiment_tag) + model_dir = Path(model_path) + + # Load the pretrained model and move to target device with bfloat16 precision + model = AutoModel.from_pretrained(model_dir) + model.eval() # Set model to evaluation mode + model.to(device=device, dtype=torch.bfloat16) + self.model = model + + # Load the processor for input/output transformation. + # Training saves processor files under a "processor/" subdirectory, but + # AutoProcessor expects them at the model root. Fall back to the + # subdirectory when the root lacks a processor_config.json. 
+ processor_dir = ( + model_dir / "processor" + if (model_dir / "processor").is_dir() and not (model_dir / "processor_config.json").exists() + else model_dir + ) + self.processor: BaseProcessor = AutoProcessor.from_pretrained(processor_dir) + self.processor.eval() + + # Store embodiment-specific configurations + self.embodiment_tag = embodiment_tag + all_modality_configs = self.processor.get_modality_configs() + if self.embodiment_tag.value not in all_modality_configs: + # Map raw checkpoint tag values to user-friendly enum names where possible. + supported_lines = [] + for tag_value in sorted(all_modality_configs.keys()): + enum_name = EmbodimentTag.reverse_lookup(tag_value) + if enum_name != tag_value: + supported_lines.append(f" {enum_name:30s} (--embodiment-tag {enum_name})") + else: + supported_lines.append(f" {tag_value:30s} (internal, no public enum)") + supported_str = "\n".join(supported_lines) + + hint = "" + if self.embodiment_tag in POSTTRAIN_TAGS: + hint = ( + f"\n\nHint: '{self.embodiment_tag.name}' is a posttrain tag that requires " + f"a finetuned checkpoint, not the base model. " + f"See the example READMEs for how to finetune and download checkpoints." + ) + elif self.embodiment_tag in FINETUNE_ONLY_TAGS: + hint = ( + f"\n\nHint: '{self.embodiment_tag.name}' is for finetuning custom robots. " + f"Use it with launch_finetune.py, not with the base model directly." + ) + + raise ValueError( + f"Embodiment tag '{self.embodiment_tag.name}' " + f"(value='{self.embodiment_tag.value}') is not supported " + f"by this checkpoint.\n\n" + f"Supported tags in this checkpoint:\n{supported_str}" + f"{hint}" + ) + self.modality_configs = { + k: v for k, v in all_modality_configs[self.embodiment_tag.value].items() if k != "rl_info" + } + self.collate_fn = self.processor.collator + + # Extract and validate language configuration + # Some embodiments (e.g. OXE_DROID) define multiple language keys for + # training-time augmentation (paraphrases). 
At inference we only use the first key. + language_keys = self.modality_configs["language"].modality_keys + language_delta_indices = self.modality_configs["language"].delta_indices + assert len(language_keys) >= 1, "At least one language key is required" + assert len(language_delta_indices) == 1, "Only one language delta index is supported" + self.language_key = language_keys[0] + + def _unbatch_observation(self, value: dict[str, Any]) -> list[dict[str, Any]]: + """Unbatch a batched observation into a list of single observations. + + Args: + value: Batched observation with shape (B, ...) for each modality + + Returns: + List of B observations, each with the batch dimension removed + """ + unbatched_obs = [] + # Infer batch size from the first video key + batch_size = value["video"][list(value["video"].keys())[0]].shape[0] + + # Split each modality along the batch dimension + for i in range(batch_size): + unbatched_value = { + "video": {k: v[i] for k, v in value["video"].items()}, + "state": {k: v[i] for k, v in value["state"].items()}, + "language": {k: v[i] for k, v in value["language"].items()}, + } + unbatched_obs.append(unbatched_value) + return unbatched_obs + + def _to_vla_step_data(self, observation: dict[str, Any]) -> VLAStepData: + """Convert a single observation into a VLAStepData object for processing. + + Args: + observation: Single observation dict with video, state, and language + + Returns: + VLAStepData object ready for processor input + """ + return VLAStepData( + images=observation["video"], + states=observation["state"], + actions={}, # No ground truth actions during inference + text=observation["language"][self.language_key][0], + embodiment=self.embodiment_tag, + ) + + def check_observation(self, observation: dict[str, Any]) -> None: + """Validate that the observation has the correct structure and types. 
+ + This method ensures that all required modalities are present and that their + data types, shapes, and dimensions match the model's expectations. + + Expected observation structure: + - video: dict[str, np.ndarray[np.uint8, (B, T, H, W, C)]] + - B: batch size + - T: temporal horizon (number of frames) + - H, W: image height and width + - C: number of channels (must be 3 for RGB) + - state: dict[str, np.ndarray[np.float32, (B, T, D)]] + - B: batch size + - T: temporal horizon (number of state observations) + - D: state dimension + - language: dict[str, list[list[str]]] + - Shape: (B, T) where each element is a string + - T: temporal horizon (typically 1 for language) + + Args: + observation: Dictionary containing video, state, and language modalities + + Raises: + AssertionError: If any validation check fails + """ + # Check that observation contains all required top-level modality keys + for modality in ["video", "state", "language"]: + assert modality in observation, f"Observation must contain a '{modality}' key" + assert isinstance(observation[modality], dict), ( + f"Observation '{modality}' must be a dictionary. " + f"Got {type(observation[modality])}: {observation[modality]}" + ) + + # Track batch size across modalities to ensure consistency + bs = -1 + + # ===== VIDEO VALIDATION ===== + # Validate each video stream defined in the modality config + for video_key in self.modality_configs["video"].modality_keys: + assert video_key in observation["video"], f"Video key '{video_key}' must be in observation" + + # Set or verify batch size consistency across all video keys + if bs == -1: + bs = len(observation["video"][video_key]) + else: + assert len(observation["video"][video_key]) == bs, ( + f"Video key '{video_key}' must have batch size {bs}. 
Got {len(observation['video'][video_key])}" + ) + + batched_video = observation["video"][video_key] + + # Verify data type is numpy array + assert isinstance(batched_video, np.ndarray), ( + f"Video key '{video_key}' must be a numpy array. Got {type(batched_video)}" + ) + + # Verify dtype is uint8 (standard for image data, range 0-255) + assert batched_video.dtype == np.uint8, ( + f"Video key '{video_key}' must be a numpy array of type np.uint8. Got {batched_video.dtype}" + ) + + # Verify shape has 5 dimensions: (B, T, H, W, C) + assert batched_video.ndim == 5, ( + f"Video key '{video_key}' must be a numpy array of shape (B, T, H, W, C), got {batched_video.shape}" + ) + + # Verify temporal dimension matches the expected horizon from config + assert batched_video.shape[1] == len(self.modality_configs["video"].delta_indices), ( + f"Video key '{video_key}'s horizon must be " + f"{len(self.modality_configs['video'].delta_indices)}. Got {batched_video.shape[1]}" + ) + + # Verify channel dimension is 3 (RGB images) + assert batched_video.shape[-1] == 3, ( + f"Video key '{video_key}'s channel 'C' must be 3. Got {batched_video.shape[-1]}" + ) + + # ===== STATE VALIDATION ===== + # Validate each state stream defined in the modality config + for state_key in self.modality_configs["state"].modality_keys: + # Check that the expected state key exists in the observation + # Must happen before indexing; see video validation above. + assert state_key in observation["state"], f"State key '{state_key}' must be in observation" + + # Set or verify batch size consistency across all state keys + if bs == -1: + bs = len(observation["state"][state_key]) + else: + assert len(observation["state"][state_key]) == bs, ( + f"State key '{state_key}' must have batch size {bs}. 
Got {len(observation['state'][state_key])}" + ) + + batched_state = observation["state"][state_key] + + # Verify data type is numpy array + assert isinstance(batched_state, np.ndarray), ( + f"State key '{state_key}' must be a numpy array. Got {type(batched_state)}" + ) + + # Verify dtype is float32 (standard for continuous state values) + assert batched_state.dtype == np.float32, ( + f"State key '{state_key}' must be a numpy array of type np.float32. Got {batched_state.dtype}" + ) + + # Verify shape has 3 dimensions: (B, T, D) + assert batched_state.ndim == 3, ( + f"State key '{state_key}' must be a numpy array of shape (B, T, D), got {batched_state.shape}" + ) + + # Verify temporal dimension matches the expected horizon from config + assert batched_state.shape[1] == len(self.modality_configs["state"].delta_indices), ( + f"State key '{state_key}'s horizon must be " + f"{len(self.modality_configs['state'].delta_indices)}. Got {batched_state.shape[1]}" + ) + + # ===== LANGUAGE VALIDATION ===== + # Validate each language stream defined in the modality config + for language_key in self.modality_configs["language"].modality_keys: + # Check that the expected language key exists in the observation + # Must happen before indexing; see video validation above. + assert language_key in observation["language"], f"Language key '{language_key}' must be in observation" + + # Set or verify batch size consistency (language uses len instead of .shape) + if bs == -1: + bs = len(observation["language"][language_key]) + else: + assert len(observation["language"][language_key]) == bs, ( + f"Language key '{language_key}' must have batch size {bs}. " + f"Got {len(observation['language'][language_key])}" + ) + + batched_language: list[list[str]] = observation["language"][language_key] + + # Verify outer structure is a list (batch dimension) + assert isinstance(batched_language, list), ( + f"Language key '{language_key}' must be a list. 
Got {type(batched_language)}" + ) + + # Validate each batch item + for batch_item in batched_language: + # Verify inner structure is a list (temporal dimension) before calling len() on it + assert isinstance(batch_item, list), f"Language batch item must be a list. Got {type(batch_item)}" + + # Verify temporal dimension matches expected horizon (report the item's own length) + assert len(batch_item) == len(self.modality_configs["language"].delta_indices), ( + f"Language key '{language_key}'s horizon must be " + f"{len(self.modality_configs['language'].delta_indices)}. Got {len(batch_item)}" + ) + + # Current implementation expects exactly one language instruction per timestep + assert len(batch_item) == 1, f"Language batch item must have exactly one item. Got {len(batch_item)}" + + # Verify the instruction itself is a string + assert isinstance(batch_item[0], str), ( + f"Language batch item must be a string. Got {type(batch_item[0])}" + ) + + def _get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Internal method to compute actions from observations. + + Pipeline: + 1. Unbatch observations into individual samples + 2. Convert each to VLAStepData and process + 3. Collate into model input batch + 4. Run model inference + 5. 
Decode and unnormalize actions + + Args: + observation: Batched observation dictionary + options: Optional parameters (currently unused) + + Returns: + Tuple of (actions_dict, info_dict) + """ + # Step 1: Split batched observation into individual observations + unbatched_observations = self._unbatch_observation(observation) + processed_inputs = [] + + # Step 2: Process each observation through the VLA processor + states = [] + for obs in unbatched_observations: + vla_step_data = self._to_vla_step_data(obs) + states.append(vla_step_data.states) # dict[str, np.ndarray[np.float32, (T, D)]] + messages = [{"type": MessageType.EPISODE_STEP.value, "content": vla_step_data}] + processed_inputs.append(self.processor(messages)) + + # Step 3: Collate processed inputs into a single batch for model + collated_inputs = self.collate_fn(processed_inputs) + collated_inputs = _rec_to_dtype(collated_inputs, dtype=torch.bfloat16) + + # Step 4: Run model inference to predict actions + with torch.inference_mode(): + model_pred = self.model.get_action(**collated_inputs) + normalized_action = model_pred["action_pred"].float() + + # Step 5: Decode actions from normalized space back to physical units + batched_states = {} + for k in self.modality_configs["state"].modality_keys: + batched_states[k] = np.stack([s[k] for s in states], axis=0) # (B, T, D) + unnormalized_action = self.processor.decode_action( + normalized_action.cpu().numpy(), self.embodiment_tag, batched_states + ) + + # Cast all actions to float32 for consistency + casted_action = {key: value.astype(np.float32) for key, value in unnormalized_action.items()} + return casted_action, {} + + def check_action(self, action: dict[str, Any]) -> None: + """Validate that the action has the correct structure and types. + + This method ensures that all required action keys are present and that their + data types, shapes, and dimensions match the model's action space. 
+ + Expected action structure: + - action: dict[str, np.ndarray[np.float32, (B, T, D)]] + - B: batch size + - T: action horizon (number of future action steps) + - D: action dimension (e.g., joint positions, velocities, gripper state) + + Args: + action: Dictionary containing action arrays for each action key + + Raises: + AssertionError: If any validation check fails + """ + # Validate each action key defined in the modality config + for action_key in self.modality_configs["action"].modality_keys: + # Check that the expected action key exists + assert action_key in action, f"Action key '{action_key}' must be in action" + + action_arr = action[action_key] + + # Verify data type is numpy array + assert isinstance(action_arr, np.ndarray), ( + f"Action key '{action_key}' must be a numpy array. Got {type(action_arr)}" + ) + + # Verify dtype is float32 (standard for continuous actions) + assert action_arr.dtype == np.float32, ( + f"Action key '{action_key}' must be a numpy array of type np.float32. Got {action_arr.dtype}" + ) + + # Verify shape has 3 dimensions: (B, T, D) + assert action_arr.ndim == 3, ( + f"Action key '{action_key}' must be a numpy array of shape (B, T, D), got {action_arr.shape}" + ) + + # Verify action horizon matches the expected temporal dimension from config + assert action_arr.shape[1] == len(self.modality_configs["action"].delta_indices), ( + f"Action key '{action_key}'s horizon must be " + f"{len(self.modality_configs['action'].delta_indices)}. Got {action_arr.shape[1]}" + ) + + def get_modality_config(self) -> dict[str, ModalityConfig]: + return self.modality_configs + + def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: + """Reset the policy to its initial state. 
+ + Args: + options: Dictionary containing the options for the reset + + Returns: + Dictionary containing the info after resetting the policy + """ + return {} + + +class Gr00tSimPolicyWrapper(PolicyWrapper): + """Wrapper for Gr00tPolicy to enable compatibility with existing Gr00t simulation environments. + + This wrapper is specifically designed for retro-fitting the Gr00t policy with the current + Gr00t simulation environment interface. It handles the transformation between the flat + observation format used by Gr00t sim environments (with keys like 'video.camera_name', + 'state.joint_positions') and the nested format expected by Gr00tPolicy. + + **Important**: If you are using other environments, custom robots, or building new environments, + you should use `Gr00tPolicy` directly and format your observations according to its interface. + This wrapper is only needed for compatibility with the existing Gr00t sim infrastructure. + + Key transformations performed by this wrapper: + - Observation keys: 'video.cam' -> observation['video']['cam'] + - Observation keys: 'state.joints' -> observation['state']['joints'] + - Language keys: 'task' or 'annotation.human.coarse_action' -> observation['language']['task'] + - Action keys: action['joints'] -> 'action.joints' + """ + + def __init__(self, policy: Gr00tPolicy, *, strict: bool = True): + """Initialize the wrapper around a Gr00tPolicy instance. + + Args: + policy: The Gr00tPolicy instance to wrap + strict: Whether to enforce strict validation (default: True) + """ + super().__init__(policy, strict=strict) + self.policy: Gr00tPolicy = policy + assert len(self.policy.modality_configs["language"].delta_indices) == 1, ( + "Only one language delta index is supported" + ) + + def check_observation(self, observation: dict[str, Any]) -> None: + """Validate observation from Gr00t sim environment format. + + This validation is specific to the flat observation format used by Gr00t sim environments. 
+ Unlike Gr00tPolicy.check_observation which expects nested dicts, this expects flat keys. + + Expected observation structure (Gr00t sim format): + - Flat keys like 'video.camera_name': np.ndarray[np.uint8, (B, T, H, W, C)] + - Flat keys like 'state.state_name': np.ndarray[np.float32, (B, T, D)] + - Language keys: tuple[str] or list[str] with shape (B,) + - Key can be 'task' or 'annotation.human.coarse_action' (for DC envs) + + Args: + observation: Flat observation dictionary from Gr00t sim environment + + Raises: + AssertionError: If any validation check fails + """ + modality_configs = self.get_modality_config() + + # ===== VIDEO VALIDATION ===== + # Check video modalities with flat key format: 'video.camera_name' + for video_key in modality_configs["video"].modality_keys: + # Construct flat key expected in Gr00t sim environment + parsed_key = f"video.{video_key}" + assert parsed_key in observation, f"Video key '{parsed_key}' must be in observation" + + batched_video = observation[parsed_key] + + # Verify data type is numpy array + assert isinstance(batched_video, np.ndarray), ( + f"Video key '{video_key}' must be a numpy array. Got {type(batched_video)}" + ) + + # Verify dtype is uint8 (standard for image data, range 0-255) + assert batched_video.dtype == np.uint8, ( + f"Video key '{video_key}' must be a numpy array of type np.uint8. Got {batched_video.dtype}" + ) + + # Verify shape has 5 dimensions: (B, T, H, W, C) + assert batched_video.ndim == 5, ( + f"Video key '{video_key}' must be a numpy array of shape (B, T, H, W, C), got {batched_video.shape}" + ) + + # Verify temporal dimension matches the expected horizon from config + assert batched_video.shape[1] == len(modality_configs["video"].delta_indices), ( + f"Video key '{video_key}'s horizon must be " + f"{len(modality_configs['video'].delta_indices)}. 
Got {batched_video.shape[1]}" + ) + + # Verify channel dimension is 3 (RGB images) + assert batched_video.shape[-1] == 3, ( + f"Video key '{video_key}'s channel 'C' must be 3. Got {batched_video.shape[-1]}" + ) + + # ===== STATE VALIDATION ===== + # Check state modalities with flat key format: 'state.state_name' + for state_key in modality_configs["state"].modality_keys: + # Construct flat key expected in Gr00t sim environment + parsed_key = f"state.{state_key}" + assert parsed_key in observation, f"State key '{parsed_key}' must be in observation" + + batched_state = observation[parsed_key] + + # Verify data type is numpy array + assert isinstance(batched_state, np.ndarray), ( + f"State key '{state_key}' must be a numpy array. Got {type(batched_state)}" + ) + + # Verify dtype is float32 (standard for continuous state values) + assert batched_state.dtype == np.float32, ( + f"State key '{state_key}' must be a numpy array of type np.float32. Got {batched_state.dtype}" + ) + + # Verify shape has 3 dimensions: (B, T, D) + assert batched_state.ndim == 3, ( + f"State key '{state_key}' must be a numpy array of shape (B, T, D), got {batched_state.shape}" + ) + + # Verify temporal dimension matches the expected horizon from config + assert batched_state.shape[1] == len(modality_configs["state"].delta_indices), ( + f"State key '{state_key}'s horizon must be " + f"{len(modality_configs['state'].delta_indices)}. 
Got {batched_state.shape[1]}" + ) + + # ===== LANGUAGE VALIDATION ===== + # Check language modalities (special handling for DC environment compatibility) + for language_key in modality_configs["language"].modality_keys: + # PATCH: Legacy compatibility for DC environments + # DC envs use 'annotation.human.coarse_action' instead of 'task' + if language_key == "task" and "annotation.human.coarse_action" in observation: + language_key = "annotation.human.coarse_action" + # /PATCH + + # Check that the expected language key exists + assert language_key in observation, f"Language key '{language_key}' must be in observation" + + # In Gr00t sim format, language is a tuple or list of strings (B,) + batched_language: tuple[str] | list[str] = observation[language_key] # (B,) + + # Verify outer structure is a tuple or list (batch dimension) + assert isinstance(batched_language, (tuple, list)), ( + f"Language key '{language_key}' must be a tuple or list. Got {type(batched_language)}" + ) + + # Spot-check: only the first batch item is verified to be a string + assert isinstance(batched_language[0], str), ( + f"Language batch item must be a string. Got {type(batched_language[0])}" + ) + + def _get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Transform Gr00t sim observation format and compute actions. + + This method transforms the flat observation format from Gr00t sim environments + into the nested format expected by Gr00tPolicy, computes actions, and transforms + them back to the flat format expected by Gr00t sim environments. 
+ + Input format (Gr00t sim): + - Flat keys: 'video.camera_name', 'state.state_name' + - Language: tuple[str] (B,) + + Output format (Gr00t sim): + - Flat keys: 'action.action_name' + + Args: + observation: Flat observation dictionary from Gr00t sim environment + options: Optional parameters (currently unused) + + Returns: + Tuple of (flat_actions_dict, info_dict) + """ + # Transform flat observation format to nested format expected by Gr00tPolicy + new_obs = {} + for modality in ["video", "state", "language"]: + new_obs[modality] = {} + for key in self.policy.modality_configs[modality].modality_keys: + if modality == "language": + # PATCH: Legacy compatibility for DC environments + if key == "task" and "annotation.human.coarse_action" in observation: + parsed_key = "annotation.human.coarse_action" + # /PATCH + else: + parsed_key = key + else: + # Construct flat key (e.g., 'video.camera' or 'state.joints') + parsed_key = f"{modality}.{key}" + + arr = observation[parsed_key] + + # Transform to nested format + if modality == "language": + # Convert from tuple[str] or list[str] (B,) to list[list[str]] (B, 1) + # Each element becomes a list with one string for temporal dimension + new_obs[modality][key] = [[str(item)] for item in arr] + else: + # Video and state arrays are already in correct format (B, T, ...) + new_obs[modality][key] = arr + + # Compute actions using the underlying Gr00tPolicy + action, info = self.policy.get_action(new_obs, options) + + # Transform actions back to flat format for Gr00t sim environment + # action['joints'] -> 'action.joints' + return {f"action.{key}": action[key] for key in action}, info + + def check_action(self, action: dict[str, Any]) -> None: + """Validate action in Gr00t sim environment format. + + This validation is specific to the flat action format used by Gr00t sim environments. + Unlike Gr00tPolicy.check_action which expects nested dicts, this expects flat keys. 
+ + Expected action structure (Gr00t sim format): + - Flat keys like 'action.action_name': np.ndarray[np.float32, (B, T, D)] + - B: batch size + - T: action horizon (number of future action steps) + - D: action dimension + + Args: + action: Flat action dictionary for Gr00t sim environment + + Raises: + AssertionError: If any validation check fails + """ + modality_configs = self.get_modality_config() + + # Validate each action key defined in the modality config + for action_key in modality_configs["action"].modality_keys: + # Construct flat key expected in Gr00t sim environment (e.g., 'action.joints') + parsed_key = f"action.{action_key}" + assert parsed_key in action, f"Action key '{parsed_key}' must be in action" + + action_arr = action[parsed_key] + + # Verify data type is numpy array + assert isinstance(action_arr, np.ndarray), ( + f"Action key '{action_key}' must be a numpy array. Got {type(action_arr)}" + ) + + # Verify dtype is float32 (standard for continuous actions) + assert action_arr.dtype == np.float32, ( + f"Action key '{action_key}' must be a numpy array of type np.float32. Got {action_arr.dtype}" + ) + + # Verify shape has 3 dimensions: (B, T, D) + assert action_arr.ndim == 3, ( + f"Action key '{action_key}' must be a numpy array of shape (B, T, D), got {action_arr.shape}" + ) + + # Verify action horizon matches the expected temporal dimension from config + assert action_arr.shape[1] == len(modality_configs["action"].delta_indices), ( + f"Action key '{action_key}'s horizon must be " + f"{len(modality_configs['action'].delta_indices)}. Got {action_arr.shape[1]}" + ) + + def get_modality_config(self) -> dict[str, ModalityConfig]: + """Get the modality configuration from the underlying policy. 
+ + Returns: + Dictionary mapping modality names to their configurations + """ + return self.policy.get_modality_config() diff --git a/vllm_omni/diffusion/models/gr00t/policy_base.py b/vllm_omni/diffusion/models/gr00t/policy_base.py new file mode 100644 index 00000000000..bab968f9de3 --- /dev/null +++ b/vllm_omni/diffusion/models/gr00t/policy_base.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any + + +class BasePolicy(ABC): + """Abstract base class for robotic control policies. + + This class defines the interface that all policies must implement, including + methods for action computation, input/output validation, and state management. + + Subclasses must implement: + - check_observation(): Validate observation format + - check_action(): Validate action format + - _get_action(): Core action computation logic + - reset(): Reset policy to initial state + """ + + def __init__(self, *, strict: bool = True): + self.strict = strict + + @abstractmethod + def check_observation(self, observation: dict[str, Any]) -> None: + """Check if the observation is valid. + + Args: + observation: Dictionary containing the current state/observation of the environment + + Raises: + AssertionError: If the observation is invalid. 
+ """ + pass + + @abstractmethod + def check_action(self, action: dict[str, Any]) -> None: + """Check if the action is valid. + + Args: + action: Dictionary containing the action to be executed + + Raises: + AssertionError: If the action is invalid. + """ + pass + + @abstractmethod + def _get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Compute and return the next action based on current observation. + + This method should be overridden by subclasses to implement policy-specific + action computation. Input validation is handled by the public get_action() method. + + Args: + observation: Dictionary containing the current state/observation + options: Optional configuration dict for action computation + + Returns: + Tuple of (action, info): + - action: Dictionary containing the action to be executed + - info: Dictionary containing additional metadata (e.g., confidence scores) + """ + pass + + def get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Compute and return the next action based on current observation with validation. + + This is the main public interface. It validates the observation, calls + the internal _get_action(), and validates the resulting action. 
+ + Args: + observation: Dictionary containing the current state/observation + options: Optional configuration dict for action computation + + Returns: + Tuple of (action, info): + - action: Dictionary containing the validated action + - info: Dictionary containing additional metadata + + Raises: + AssertionError/ValueError: If observation or action validation fails + """ + if self.strict: + self.check_observation(observation) + action, info = self._get_action(observation, options) + if self.strict: + self.check_action(action) + return action, info + + @abstractmethod + def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: + """Reset the policy to its initial state. + + Args: + options: Dictionary containing the options for the reset + + Returns: + Dictionary containing the info after resetting the policy + """ + pass + + +class PolicyWrapper(BasePolicy): + """Base wrapper class for composing policy behaviors. + + Note: This base implementation only forwards reset(). Subclasses should + implement validation logic and additional functionality as needed. 
+ """ + + def __init__(self, policy: BasePolicy, *, strict: bool = True): + super().__init__(strict=strict) + self.policy = policy + + def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: + return self.policy.reset(options) diff --git a/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py b/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py index a820fb56aec..06cc8bd5195 100644 --- a/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py +++ b/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import torch import torch.nn as nn from transformers.models.qwen3_vl.modeling_qwen3_vl import ( @@ -15,9 +13,7 @@ Qwen3VLVisionModel, Unpack, apply_rotary_pos_emb, - check_model_inputs, create_causal_mask, - deprecate_kwarg, eager_attention_forward, ) from transformers.models.qwen3_vl.modeling_qwen3_vl import ( @@ -55,7 +51,6 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -86,9 +81,9 @@ def forward( key_states = torch.cat([past_key_values[self.layer_idx][0], key_states], dim=2) value_states = torch.cat([past_key_values[self.layer_idx][1], value_states], dim=2) - attention_interface = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -134,7 +129,6 @@ def __init__(self, config: Qwen3VLTextConfig): self.post_init() - @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, @@ 
-176,7 +170,7 @@ def forward( attention_mask = create_causal_mask( config=self.config, - input_embeds=inputs_embeds, + inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index a381640b507..6ce70dd2b3a 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -151,6 +151,11 @@ "pipeline_internvla_a1", "InternVLAA1Pipeline", ), + "Gr00tN1d7Pipeline": ( + "gr00t", + "pipeline_gr00t", + "Gr00tN1d7Pipeline", + ), "LongCatImageEditPipeline": ( "longcat_image", "pipeline_longcat_image_edit", @@ -473,6 +478,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "BagelPipeline": "get_bagel_post_process_func", "MingImagePipeline": "get_ming_image_post_process_func", "InternVLAA1Pipeline": "get_internvla_a1_post_process_func", + "Gr00tN1d7Pipeline": "get_gr00t_n1d7_post_process_func", "LongCatImageEditPipeline": "get_longcat_image_post_process_func", "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", "FluxKontextPipeline": "get_flux_kontext_post_process_func", diff --git a/vllm_omni/model_executor/models/gr00t/__init__.py b/vllm_omni/model_executor/models/gr00t/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/model_executor/models/gr00t/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/model_executor/models/gr00t/pipeline.py b/vllm_omni/model_executor/models/gr00t/pipeline.py new file mode 100644 index 00000000000..38a76f29769 --- /dev/null +++ b/vllm_omni/model_executor/models/gr00t/pipeline.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""GR00T N1.7 single-stage policy topology.""" + +from vllm_omni.config.stage_config 
import ( + PipelineConfig, + StageExecutionType, + StagePipelineConfig, +) + +GR00T_N1D7_PIPELINE = PipelineConfig( + model_type="Gr00tN1d7", + model_arch="Gr00tN1d7Pipeline", + hf_architectures=("Gr00tN1d7",), + stages=( + StagePipelineConfig( + stage_id=0, + model_stage="diffusion", + execution_type=StageExecutionType.DIFFUSION, + input_sources=(), + final_output=True, + final_output_type="actions", + model_arch="Gr00tN1d7Pipeline", + ), + ), +) From f4d1c415fc95d04b70541a8a69e5c6380440a4ba Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Fri, 22 May 2026 12:33:45 +0700 Subject: [PATCH 2/8] feat(gr00t): add OpenPI client, MolmoSpaces eval, and server launcher for GR00T-N1.7 Signed-off-by: Zhengyuan Su --- examples/online_serving/gr00t/README.md | 109 ++++++- .../gr00t/molmospace_gr00t_eval_demo.py | 109 +++++++ .../online_serving/gr00t/openpi_client.py | 308 ++++++++++++++++++ examples/online_serving/gr00t/run_server.sh | 25 ++ 4 files changed, 538 insertions(+), 13 deletions(-) create mode 100644 examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py create mode 100644 examples/online_serving/gr00t/openpi_client.py create mode 100755 examples/online_serving/gr00t/run_server.sh diff --git a/examples/online_serving/gr00t/README.md b/examples/online_serving/gr00t/README.md index 38db1a511a1..fa66f5abc20 100644 --- a/examples/online_serving/gr00t/README.md +++ b/examples/online_serving/gr00t/README.md @@ -1,33 +1,116 @@ -# GR00T OpenPI Serving +# GR00T-N1.7 OpenPI Example -This example serves NVIDIA Isaac GR00T N1.7 through the OpenPI-compatible robot websocket endpoint. +This example shows how to serve NVIDIA Isaac GR00T-N1.7 with `vllm serve --omni` +and connect a compatible OpenPI websocket client. The deployment is configured +for the DROID embodiment (`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`) and exposes +the GR00T action keys `eef_9d`, `gripper_position`, and `joint_position`. 
-## Requirements +## Files -- Install the OpenPI client dependency used by the websocket protocol: +- `run_server.sh`: launch GR00T-N1.7 OpenPI serving +- `openpi_client.py`: minimal websocket client that sends a synthetic GR00T observation +- `molmospace_gr00t_eval_demo.py`: MolmoSpaces evaluation demo wrapper + +## Environment requirements + +- `run_server.sh`, `vllm serve`, and `openpi_client.py`: use the local + `vllm-omni` environment. +- `openpi_client.py` extra deps: ```bash -pip install openpi-client +pip install openpi-client websockets ``` +Optional MolmoSpaces dependencies: + +- `molmospace_gr00t_eval_demo.py` requires a working MolmoSpaces checkout with + the `examples.gr00t_openpi` package importable on `PYTHONPATH`, and uses EGL + for MuJoCo. The default asset cache location is + `$HOME/.cache/molmospaces/gr00t-assets`. + ## Start the server From the repository root: ```bash -vllm serve nvidia/GR00T-N1.7-3B \ - --omni \ - --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml +CUDA_VISIBLE_DEVICES=0 \ +examples/online_serving/gr00t/run_server.sh ``` +The launcher honors: + +- `CUDA_VISIBLE_DEVICES`: GPU selection (default `0`) +- `DEPLOY_CONFIG`: stage config path (default `vllm_omni/deploy/Gr00tN1d7.yaml`) +- `HOST` / `PORT`: bind address (defaults `127.0.0.1:8000`) +- `SERVED_MODEL_NAME`: alias served via the OpenAI route (default `gr00t-n1d7`) +- `VLLM_WORKER_MULTIPROC_METHOD`: defaults to `spawn` + The websocket endpoint is: -```text -ws://127.0.0.1:8000/v1/realtime/robot/openpi +- `ws://127.0.0.1:8000/v1/realtime/robot/openpi` + +The server handshake advertises the configured embodiment, image resolution +(`[180, 320]`), action horizon, and action key set through `policy_server_config` +in `vllm_omni/deploy/Gr00tN1d7.yaml`. 
+ +## Run the client + +From the repository root: + +Environment: + +- run this in the `vllm-omni` repo environment +- if imports are missing, install `openpi-client` and `websockets` + +```bash +python examples/online_serving/gr00t/openpi_client.py \ + --host 127.0.0.1 \ + --port 8000 ``` -## Request and response +The client sends one synthetic two-frame observation crafted for GR00T's DROID +contract: + +- exterior + wrist video tensors shaped `(1, 2, 180, 320, 3)` with a slowly + varying gradient pattern (constant frames have triggered SVD failures in the + Qwen3-VL backbone in practice) +- `state` with `eef_9d` (1,1,9), `gripper_position` (1,1,1), and + `joint_position` (1,1,7) +- `language` keyed by `annotation.language.language_instruction` + +It validates: + +- GR00T metadata contract (`image_resolution`, `action_horizon`, `action_keys`, + `embodiment_tag`) +- presence and 3D shape of each expected action key +- finite action values +- post-call reset response + +## Run MolmoSpaces benchmarks + +`molmospace_gr00t_eval_demo.py` evaluates GR00T-N1.7 through the same vLLM +OpenPI server on MolmoSpaces benchmarks. Install MolmoSpaces and make sure +`examples.gr00t_openpi.gr00t_openpi_policy` is importable, then run: + +```bash +BENCH="${MOLMOSPACES_BENCHMARK_DIR}/20260327/ithor/FrankaCloseHardBench" +BENCH="${BENCH}/FrankaCloseHardBench_20260206_json_benchmark" +PYTHONPATH="${MOLMOSPACES_ROOT}" \ +MUJOCO_EGL_DEVICE_ID=0 \ +python examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py \ + --host 127.0.0.1 \ + --port 8000 \ + --benchmark_dir "${BENCH}" \ + --output_dir outputs/gr00t/molmospaces \ + --max_episodes 1 \ + --task_horizon_steps 240 +``` -The OpenPI serving layer forwards raw robot observations through `sampling_params.extra_args["robot_obs"]`. The GR00T pipeline converts the observation to the local GR00T policy input shape and returns `multimodal_output["actions"]` as `dict[str, np.ndarray]`. 
+`PYTHONPATH` must include the MolmoSpaces workspace root so that +`examples.gr00t_openpi.gr00t_openpi_policy` is importable. +`MUJOCO_EGL_DEVICE_ID` selects the EGL device for headless rendering; set +it to match the physical GPU index (may differ from `CUDA_VISIBLE_DEVICES`). -The server handshake advertises the configured embodiment and action schema through `policy_server_config` in `vllm_omni/deploy/Gr00tN1d7.yaml`. +The wrapper subclasses `Gr00tOpenPIPolicyConfig` / `Gr00tOpenPIEvalConfig` from +the MolmoSpaces workspace and overrides the policy backend host/port at runtime +so the eval harness talks to your local vLLM-Omni GR00T server. diff --git a/examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py b/examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py new file mode 100644 index 00000000000..a60c01be9b0 --- /dev/null +++ b/examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Drive MolmoSpaces evaluation against a running vLLM GR00T-N1.7 server. + +Wires the MolmoSpaces evaluation harness to the ``Gr00tOpenPIEvalConfig`` / +``Gr00tOpenPIPolicyConfig`` defined under +``examples.gr00t_openpi.gr00t_openpi_policy`` in the MolmoSpaces workspace and +points its policy backend at the local vLLM-Omni websocket server. +""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +# Remove this script's directory from sys.path so that `import openpi_client` +# resolves to the installed openpi-client package rather than our local +# openpi_client.py (same guard used in openpi_client.py). 
+_example_dir = str(Path(__file__).resolve().parent) +if sys.path and sys.path[0] == _example_dir: + sys.path.pop(0) + +os.environ.setdefault("MUJOCO_GL", "egl") +os.environ.setdefault("PYOPENGL_PLATFORM", "egl") +os.environ.setdefault("MLSPACES_ASSETS_DIR", str(Path.home() / ".cache" / "molmospaces" / "gr00t-assets")) + +_DEMO_HOST = os.environ.get("VLLM_OMNI_DEMO_HOST", "127.0.0.1") +_DEMO_PORT = int(os.environ.get("VLLM_OMNI_DEMO_PORT", "8000")) + +# Import the GR00T MolmoSpaces base configs at module top level so the +# subclasses below are pickle-resolvable (worker processes import this module +# fresh via __main__). +from examples.gr00t_openpi.gr00t_openpi_policy import ( # noqa: E402 + Gr00tOpenPIEvalConfig, + Gr00tOpenPIPolicyConfig, +) + + +class Gr00tVllmOmniPolicyConfig(Gr00tOpenPIPolicyConfig): + host: str = _DEMO_HOST + port: int = _DEMO_PORT + + +class Gr00tVllmOmniEvalConfig(Gr00tOpenPIEvalConfig): + policy_config: Gr00tVllmOmniPolicyConfig = Gr00tVllmOmniPolicyConfig() + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--benchmark_dir", + required=True, + help=( + "Path to a MolmoSpaces benchmark directory, for example " + "$MOLMOSPACES_BENCHMARK_DIR/20260327/ithor/FrankaCloseHardBench/" + "FrankaCloseHardBench_20260206_json_benchmark" + ), + ) + parser.add_argument("--max_episodes", type=int, default=1) + parser.add_argument("--task_horizon_steps", type=int, default=80) + parser.add_argument( + "--output_dir", + required=True, + help="Directory to write evaluation outputs (created if missing).", + ) + parser.add_argument("--episode_idx", type=int, default=None) + args = parser.parse_args() + + os.environ["VLLM_OMNI_DEMO_HOST"] = args.host + os.environ["VLLM_OMNI_DEMO_PORT"] = str(args.port) + Gr00tVllmOmniPolicyConfig.model_fields["host"].default = args.host + 
Gr00tVllmOmniPolicyConfig.model_fields["port"].default = args.port + + # Import after env vars are set so MuJoCo picks EGL. + from molmo_spaces.evaluation import run_evaluation + + cfg_cls = Gr00tVllmOmniEvalConfig + + output_dir = args.output_dir + Path(output_dir).mkdir(parents=True, exist_ok=True) + + print(f"[eval] benchmark_dir={args.benchmark_dir}") + print(f"[eval] max_episodes={args.max_episodes} task_horizon_steps={args.task_horizon_steps}") + print(f"[eval] remote policy: ws://{args.host}:{args.port}/v1/realtime/robot/openpi") + + results = run_evaluation( + eval_config_cls=cfg_cls, + benchmark_dir=Path(args.benchmark_dir), + max_episodes=args.max_episodes, + task_horizon_steps=args.task_horizon_steps, + num_workers=1, + use_wandb=False, + output_dir=output_dir, + episode_idx=args.episode_idx, + ) + + print(f"[eval] success={results.success_count}/{results.total_count} ({results.success_rate:.1%})") + print(f"[eval] output_dir={results.output_dir}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/online_serving/gr00t/openpi_client.py b/examples/online_serving/gr00t/openpi_client.py new file mode 100644 index 00000000000..1661b310f37 --- /dev/null +++ b/examples/online_serving/gr00t/openpi_client.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Minimal GR00T-N1.7 OpenPI websocket client demo. + +Sends a single synthetic observation crafted for the GR00T DROID embodiment +contract advertised by ``vllm_omni/deploy/Gr00tN1d7.yaml``: + +- two-frame video history at ``(180, 320)`` per camera (exterior + wrist) +- ``state`` with ``eef_9d`` (1,1,9), ``gripper_position`` (1,1,1) and + ``joint_position`` (1,1,7) +- ``language`` keyed by ``annotation.language.language_instruction`` + +Expects the server to return an action dict with keys +``{"eef_9d", "gripper_position", "joint_position"}``. 
+ +The synthetic image uses a slowly varying gradient rather than a constant. +Constant frames have caused SVD failures in the Qwen3-VL visual backbone +during local smoke testing, so we deliberately keep some spatial variance. +""" + +from __future__ import annotations + +import argparse +import sys +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +from vllm.logger import init_logger + +try: + import websockets.sync.client +except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("GR00T OpenPI example requires `websockets`.") from exc + +# NOTE: this directory does NOT contain a local file named ``openpi_client.py`` +# clashing with the installed package (this file imports nothing from it), but +# we apply the same defensive ``sys.path`` rewrite so the script works even +# when launched as ``python openpi_client.py`` from inside this directory. +try: + example_dir = str(Path(__file__).resolve().parent) + removed_path = False + if sys.path and sys.path[0] == example_dir: + sys.path.pop(0) + removed_path = True + try: + from openpi_client import msgpack_numpy + finally: + if removed_path: + sys.path.insert(0, example_dir) +except ImportError as exc: # pragma: no cover - runtime dependency guard + raise ImportError("GR00T OpenPI example requires `openpi-client`.") from exc + +logger = init_logger(__name__) + +PING_INTERVAL_SECS = 300 +PING_TIMEOUT_SECS = 3600 +DEFAULT_HOST = "127.0.0.1" +DEFAULT_PORT = 8000 +DEFAULT_PATH = "/v1/realtime/robot/openpi" +DEFAULT_PROMPT = "pick up the object and place it in the bin" +LANGUAGE_KEY = "annotation.language.language_instruction" +EXPECTED_ACTION_KEYS = ("eef_9d", "gripper_position", "joint_position") +IMAGE_HEIGHT = 180 +IMAGE_WIDTH = 320 + + +@dataclass(frozen=True) +class Gr00tServerMetadata: + image_resolution: tuple[int, int] + action_horizon: int + action_keys: tuple[str, ...] 
+ embodiment_tag: str + needs_session_id: bool + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> Gr00tServerMetadata: + required_keys = ( + "image_resolution", + "action_horizon", + "action_keys", + "embodiment_tag", + "needs_session_id", + ) + missing_keys = [key for key in required_keys if key not in payload] + if missing_keys: + raise ValueError(f"Missing GR00T metadata keys: {missing_keys}") + + image_resolution = payload["image_resolution"] + if not isinstance(image_resolution, (list, tuple)) or len(image_resolution) != 2: + raise ValueError(f"Invalid image_resolution: {image_resolution!r}") + + return cls( + image_resolution=(int(image_resolution[0]), int(image_resolution[1])), + action_horizon=int(payload["action_horizon"]), + action_keys=tuple(str(k) for k in payload["action_keys"]), + embodiment_tag=str(payload["embodiment_tag"]), + needs_session_id=bool(payload["needs_session_id"]), + ) + + +class OpenPIWebsocketClient: + def __init__( + self, + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + ) -> None: + self._uri = f"ws://{host}:{port}{path}" + self._packer = msgpack_numpy.Packer() + self._ws, self._server_metadata = self._connect() + + def _connect(self): + logger.info("Connecting to %s", self._uri) + conn = websockets.sync.client.connect( + self._uri, + compression=None, + max_size=None, + ping_interval=PING_INTERVAL_SECS, + ping_timeout=PING_TIMEOUT_SECS, + ) + metadata = msgpack_numpy.unpackb(conn.recv()) + if not isinstance(metadata, dict): + raise TypeError(f"Expected dict metadata from server, got {type(metadata)!r}") + return conn, metadata + + def get_server_metadata(self) -> dict[str, Any]: + return dict(self._server_metadata) + + def infer(self, obs: dict[str, Any]) -> dict[str, np.ndarray]: + payload = dict(obs) + payload["endpoint"] = "infer" + self._ws.send(self._packer.pack(payload)) + response = self._ws.recv() + if isinstance(response, str): + raise RuntimeError(f"Inference 
failed: {response}") + decoded = msgpack_numpy.unpackb(response) + if isinstance(decoded, dict) and decoded.get("type") == "error": + raise RuntimeError(f"GR00T server inference failed: {decoded.get('message')!r}") + if not isinstance(decoded, dict): + raise TypeError(f"Expected dict actions from GR00T server, got {type(decoded)!r}") + return {str(key): np.asarray(value, dtype=np.float32) for key, value in decoded.items()} + + def reset(self, reset_info: dict[str, Any] | None = None) -> str: + payload = dict(reset_info or {}) + payload["endpoint"] = "reset" + self._ws.send(self._packer.pack(payload)) + response = self._ws.recv() + if isinstance(response, str): + return response + decoded = msgpack_numpy.unpackb(response) + if not isinstance(decoded, dict) or decoded.get("status") != "reset successful": + raise RuntimeError(f"Unexpected reset response: {decoded!r}") + return str(decoded["status"]) + + def close(self) -> None: + self._ws.close() + + +def make_synthetic_frame(height: int = IMAGE_HEIGHT, width: int = IMAGE_WIDTH, *, seed: int = 0) -> np.ndarray: + """Produce a deterministic, non-constant RGB frame. + + Constant frames have triggered SVD failures in the Qwen3-VL backbone in + practice. We blend a horizontal and vertical gradient so the image has + real spatial variance while staying reproducible. + """ + + rng = np.random.default_rng(seed) + y_grad = np.linspace(0, 255, height, dtype=np.float32)[:, None] + x_grad = np.linspace(0, 255, width, dtype=np.float32)[None, :] + base = (0.5 * y_grad + 0.5 * x_grad).astype(np.float32) + frame = np.stack([base, np.flipud(base), base.T[:height, :width] if base.shape[0] == width else base], axis=-1) + if frame.shape != (height, width, 3): + # Fall back to a tiled gradient if the transpose path was incompatible. 
+ frame = np.stack([base, 255.0 - base, (base + 64.0) % 256.0], axis=-1) + frame = frame + rng.uniform(-4.0, 4.0, size=frame.shape).astype(np.float32) + return np.clip(frame, 0, 255).astype(np.uint8) + + +def make_synthetic_observation(*, prompt: str, session_id: str) -> dict[str, Any]: + """Build a single GR00T observation with two-frame video history.""" + + exo_t0 = make_synthetic_frame(seed=1) + exo_t1 = make_synthetic_frame(seed=2) + wrist_t0 = make_synthetic_frame(seed=3) + wrist_t1 = make_synthetic_frame(seed=4) + + video = { + "exterior_image_1_left": np.stack([exo_t0, exo_t1])[None, ...], + "wrist_image_left": np.stack([wrist_t0, wrist_t1])[None, ...], + } + # eef_9d = xyz (3) + rot6d (6). A zero rot6d is rank-deficient and causes + # Rotation.from_matrix() to fail with "SVD did not converge". Use an identity + # pose instead: xyz=[0,0,0], rot6d=[1,0,0,0,1,0] (first two cols of I). + eef_9d_identity = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=np.float32) + state = { + "eef_9d": eef_9d_identity.reshape(1, 1, 9), + "gripper_position": np.zeros((1, 1, 1), dtype=np.float32), + "joint_position": np.zeros((1, 1, 7), dtype=np.float32), + } + language = {LANGUAGE_KEY: [[prompt]]} + + return { + "session_id": session_id, + "video": video, + "state": state, + "language": language, + } + + +def validate_actions( + actions: dict[str, np.ndarray], + *, + expected_action_horizon: int, +) -> None: + missing = [k for k in EXPECTED_ACTION_KEYS if k not in actions] + if missing: + raise AssertionError(f"Missing action keys from server response: {missing}") + + expected_dims = {"eef_9d": 9, "gripper_position": 1, "joint_position": 7} + for key, last_dim in expected_dims.items(): + action = actions[key] + if action.ndim != 3: + raise AssertionError(f"Action {key!r} must be 3D, got shape {action.shape}") + if action.shape[1] != expected_action_horizon: + raise AssertionError( + f"Action {key!r} horizon mismatch: expected {expected_action_horizon}, got 
{action.shape[1]}" + ) + if action.shape[-1] != last_dim: + raise AssertionError(f"Action {key!r} trailing dim mismatch: expected {last_dim}, got {action.shape[-1]}") + if not np.isfinite(action).all(): + raise AssertionError(f"Action {key!r} contains non-finite values") + + +def run_policy_session( + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + path: str = DEFAULT_PATH, + prompt: str = DEFAULT_PROMPT, + session_id: str | None = None, +) -> dict[str, Any]: + session_id = session_id or str(uuid.uuid4()) + observation = make_synthetic_observation(prompt=prompt, session_id=session_id) + + client = OpenPIWebsocketClient(host=host, port=port, path=path) + try: + metadata = client.get_server_metadata() + actions = client.infer(observation) + reset_status = client.reset({"session_id": session_id}) + return { + "metadata": metadata, + "actions": actions, + "reset_status": reset_status, + "session_id": session_id, + } + finally: + client.close() + + +def format_action_summary(key: str, action: np.ndarray) -> str: + return ( + f"Action {key!r}: shape={tuple(action.shape)} dtype={action.dtype} " + f"min={float(action.min()):.6f} max={float(action.max()):.6f}" + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="GR00T-N1.7 OpenPI websocket client demo.") + parser.add_argument("--host", default=DEFAULT_HOST) + parser.add_argument("--port", type=int, default=DEFAULT_PORT) + parser.add_argument("--path", default=DEFAULT_PATH) + parser.add_argument("--prompt", default=DEFAULT_PROMPT) + parser.add_argument("--session-id", default=None) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + result = run_policy_session( + host=args.host, + port=args.port, + path=args.path, + prompt=args.prompt, + session_id=args.session_id, + ) + + server_metadata = Gr00tServerMetadata.from_dict(result["metadata"]) + validate_actions(result["actions"], expected_action_horizon=server_metadata.action_horizon) + + 
print(f"Server embodiment: {server_metadata.embodiment_tag}") + print(f"Server image_resolution: {server_metadata.image_resolution}") + print(f"Server action_horizon: {server_metadata.action_horizon}") + print(f"Server action_keys: {server_metadata.action_keys}") + for key in sorted(result["actions"]): + print(format_action_summary(key, result["actions"][key])) + print(f"Reset status: {result['reset_status']}") + print(f"Session ID: {result['session_id']}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/online_serving/gr00t/run_server.sh b/examples/online_serving/gr00t/run_server.sh new file mode 100755 index 00000000000..fc4ca79b7a8 --- /dev/null +++ b/examples/online_serving/gr00t/run_server.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +set -euo pipefail + +MODEL="${MODEL:-nvidia/GR00T-N1.7-3B}" +HOST="${HOST:-127.0.0.1}" +PORT="${PORT:-8000}" +DEPLOY_CONFIG="${DEPLOY_CONFIG:-vllm_omni/deploy/Gr00tN1d7.yaml}" +SERVED_MODEL_NAME="${SERVED_MODEL_NAME:-gr00t-n1d7}" + +args=( + serve + "$MODEL" + --omni + --host "$HOST" + --port "$PORT" + --served-model-name "$SERVED_MODEL_NAME" + --stage-configs-path "$DEPLOY_CONFIG" + --disable-log-stats +) + +CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" \ +VLLM_WORKER_MULTIPROC_METHOD="${VLLM_WORKER_MULTIPROC_METHOD:-spawn}" \ +vllm "${args[@]}" From bb2964a65bd012af2b4506e381f3ed39356b3a06 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Sun, 24 May 2026 19:00:21 +0700 Subject: [PATCH 3/8] remove the legacy cache_position argument Signed-off-by: Zhengyuan Su --- vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py b/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py index 06cc8bd5195..40a4a111b6b 100644 --- 
a/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py +++ b/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py @@ -172,7 +172,6 @@ def forward( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, - cache_position=cache_position, past_key_values=past_key_values, position_ids=text_position_ids, ) From 9f5e89afb44a88c62e22f9967b05d8f07020da9c Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Mon, 1 Jun 2026 14:28:02 +0700 Subject: [PATCH 4/8] Trim training-related code. Signed-off-by: Zhengyuan Su --- requirements/common.txt | 1 - .../models/gr00t/configs/model/gr00t_n1d7.py | 62 +- .../models/gr00t/dataio/interfaces.py | 40 +- .../state_action/state_action_processor.py | 19 +- .../models/gr00t/modeling/gr00t_n1d7.py | 169 ------ .../gr00t/modeling/image_augmentations.py | 564 ------------------ .../gr00t/modeling/processing_gr00t_n1d7.py | 197 ++---- vllm_omni/diffusion/models/gr00t/policy.py | 21 +- 8 files changed, 65 insertions(+), 1008 deletions(-) delete mode 100755 vllm_omni/diffusion/models/gr00t/modeling/image_augmentations.py diff --git a/requirements/common.txt b/requirements/common.txt index 1af2eaef55a..e649450afdb 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -10,7 +10,6 @@ omegaconf>=2.3.0 diffusers==0.38.0 safetensors>=0.8.0rc0 accelerate==1.12.0 -albumentations==1.4.18 soundfile>=0.13.1 cache-dit==1.3.0 tqdm>=4.66.0 diff --git a/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py index 8b67958d330..9eb3ee84945 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py @@ -13,12 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json -from dataclasses import MISSING, asdict, dataclass, is_dataclass -from enum import Enum -from pathlib import Path +from dataclasses import MISSING, dataclass -import torch from transformers import PretrainedConfig from . import register_model_config @@ -31,8 +27,6 @@ def _default_diffusion_model_cfg() -> dict: "num_attention_heads": 32, "attention_head_dim": 48, "norm_type": "ada_norm", - "dropout": 0.2, - "final_dropout": True, "output_dim": 1024, "interleave_self_attention": True, } @@ -62,7 +56,6 @@ class Gr00tN1d7Config(PretrainedConfig): reproject_vision: bool = False use_flash_attention: bool = True load_bf16: bool = False # Enable BF16 loading - backbone_trainable_params_fp32: bool = True ### Processing parameters image_crop_size: tuple[int, int] | None = (230, 230) @@ -73,9 +66,6 @@ class Gr00tN1d7Config(PretrainedConfig): random_rotation_angle: int | None = None color_jitter_params: dict[str, float] | None = None - use_albumentations_transforms: bool = True - # Extra augmentation config (mask-based and others). 
- extra_augmentation_config: dict | None = None formalize_language: bool = True apply_sincos_state_encoding: bool = False # Global flag to enable per-embodiment sin/cos encoding use_percentiles: bool = True @@ -103,18 +93,9 @@ class Gr00tN1d7Config(PretrainedConfig): # Flow matching parameters num_inference_timesteps: int = 4 - noise_beta_alpha: float = 1.5 - noise_beta_beta: float = 1.0 - noise_s: float = 0.999 num_timestep_buckets: int = 1000 - # Training parameters - tune_projector: bool = True - tune_diffusion_model: bool = True - tune_vlln: bool = True - - # State augmentation parameters - state_dropout_prob: float = 0.8 # State dropout probability + # State augmentation parameters (inference-relevant only) exclude_state: bool = False # Zero out all state inputs (ablation) use_mean_std: bool = False # Use mean/std normalization instead of min/max @@ -141,44 +122,5 @@ def __init__(self, **kwargs): else: self.diffusion_model_cfg = dict(self.diffusion_model_cfg) - def to_filtered_dict(self, exclude_augment: bool = True) -> dict: - """Return a dictionary representation of this config, optionally excluding augmentation keys.""" - if is_dataclass(self): - cfg = asdict(self) - else: - cfg = dict(self.__dict__) - - if exclude_augment: - exclude_keys = { - "random_rotation_angle", - "color_jitter_params", - "use_albumentations_transforms", - "formalize_language", - "image_crop_size", - "image_target_size", - "shortest_image_edge", - "crop_fraction", - } - cfg = {k: v for k, v in cfg.items() if k not in exclude_keys} - - return cfg - - def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str: - """Return a JSON string of this config, optionally excluding augmentation keys.""" - - def default(o): - if isinstance(o, (Path, torch.dtype, torch.device)): - return str(o) - if isinstance(o, Enum): - return o.value - return str(o) - - return json.dumps( - self.to_filtered_dict(exclude_augment), - indent=2, - default=default, - **kwargs, - ) - 
register_model_config("Gr00tN1d7", Gr00tN1d7Config) diff --git a/vllm_omni/diffusion/models/gr00t/dataio/interfaces.py b/vllm_omni/diffusion/models/gr00t/dataio/interfaces.py index ed8e2450004..f2d263600f4 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/interfaces.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/interfaces.py @@ -68,12 +68,6 @@ def set_statistics(self, statistics: dict[str, Any], override: bool = False) -> """Set normalization statistics.""" pass - def train(self): - self.training = True - - def eval(self): - self.training = False - def get_modality_configs(self) -> dict[str, dict[str, ModalityConfig]]: """Get the modality configurations. @@ -107,37 +101,5 @@ def set_processor(self, processor: BaseProcessor): self.processor = processor def get_dataset_statistics(self) -> dict[str, Any]: - """Get the dataset statistics. This is only required for dataloaders for robtics datasets.""" + """Get dataset statistics.""" raise NotImplementedError() - - -# # Example chat formats (processor input) -# # Single step -# messages = [ -# {"type": "episode_step", "content": VLAStepData}, -# ] -# # Single episode -# messages = [ -# {"type": MessageType.START_OF_EPISODE.value, "content": ""}, -# {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, -# {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, -# {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, -# {"type": MessageType.END_OF_EPISODE.value, "content": ""}, -# ] -# # Multiple episodes -# messages = [ -# {"type": MessageType.START_OF_EPISODE.value, "content": ""}, -# {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, -# {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, -# {"type": MessageType.END_OF_EPISODE.value, "content": ""}, -# {"type": MessageType.START_OF_EPISODE.value, "content": ""}, -# {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, -# {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, -# 
{"type": MessageType.END_OF_EPISODE.value, "content": ""}, -# ] - -# # Example usage -# messages = dataset[idx] -# model_input = processor(messages) -# model_output = model(**model_input) # or model.generate(**model_input) -# decoded_action = processor.decode_action(model_output) diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py index 6126c5610cf..3c861c0c259 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py @@ -13,15 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Unified processor for robot state and action data. - -Handles: -- State normalization (min/max, mean/std, sin/cos encoding) -- Action normalization -- Absolute <-> Relative action representation conversion -- Action processing with state dependency -""" +"""Unified state and action processor for robotics.""" from copy import deepcopy @@ -101,14 +93,6 @@ def __init__( if statistics is not None: self.set_statistics(statistics) - self.train() - - def train(self): - self.training = True - - def eval(self): - self.training = False - def set_statistics( self, statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]], @@ -522,7 +506,6 @@ def apply( if action: processed_action = self.apply_action(action, embodiment_tag, state=state) else: - assert not self.training, "Action is required in training mode" processed_action = {} return processed_state, processed_action diff --git a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py index 4a84d6eceb5..4ed48841022 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py @@ -17,9 +17,7 @@ from typing import Any 
import torch -import torch.nn.functional as F from torch import nn -from torch.distributions import Beta from transformers import AutoConfig, AutoModel, PreTrainedModel from transformers.feature_extraction_utils import BatchFeature from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig @@ -146,64 +144,7 @@ def __init__(self, config: Gr00tN1d7Config): self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim) nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) - # State dropout parameters - self.state_dropout_prob = config.state_dropout_prob - - self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta) self.num_timestep_buckets = config.num_timestep_buckets - self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln) - - def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool): - self.tune_projector = tune_projector - self.tune_diffusion_model = tune_diffusion_model - self.tune_vlln = tune_vlln - for p in self.parameters(): - p.requires_grad = True - if not tune_projector: - self.state_encoder.requires_grad_(False) - self.action_encoder.requires_grad_(False) - self.action_decoder.requires_grad_(False) - if self.config.add_pos_embed: - self.position_embedding.requires_grad_(False) - if not tune_diffusion_model: - self.model.requires_grad_(False) - if not tune_vlln: - self.vlln.requires_grad_(False) - self.vl_self_attention.requires_grad_(False) - logger.debug(f"Tune action head projector: {self.tune_projector}") - logger.debug(f"Tune action head diffusion model: {self.tune_diffusion_model}") - logger.debug(f"Tune action head vlln: {self.tune_vlln}") - # Check if any parameters are still trainable. If not, log a warning. 
- if not tune_projector and not tune_diffusion_model and not tune_vlln: - for name, p in self.named_parameters(): - if p.requires_grad: - logger.debug(f"Action head trainable parameter: {name}") - if not any(p.requires_grad for p in self.parameters()): - logger.warning("No action head trainable parameters found.") - - def set_frozen_modules_to_eval_mode(self): - """ - Huggingface will call model.train() at each training_step. To ensure - the expected behaviors for modules like dropout, batchnorm, etc., we - need to call model.eval() for the frozen modules. - """ - if self.training: - if not self.tune_projector: - self.state_encoder.eval() - self.action_encoder.eval() - self.action_decoder.eval() - if self.config.add_pos_embed: - self.position_embedding.eval() - if not self.tune_diffusion_model: - self.model.eval() - if not self.tune_vlln: - self.vlln.eval() - self.vl_self_attention.eval() - - def sample_time(self, batch_size, device, dtype): - sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype) - sample = (1 - sample) * self.config.noise_s - return sample def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature: backbone_features = backbone_output["backbone_features"] @@ -212,109 +153,6 @@ def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature backbone_output["backbone_features"] = backbone_features return backbone_output - def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature: - """ - Forward pass through the action head. 
- - Args: - backbone_output: Output from the backbone model containing: - - backbone_features: [B, seq_len, backbone_embedding_dim] - - backbone_attention_mask: [B, seq_len] - action_input: Input containing: - - state: [B, state_dim] - - action: [B, action_horizon, action_dim] (during training) - - embodiment_id: [B] (embodiment IDs) - - action_mask: [B, action_horizon, action_dim] - - Returns: - BatchFeature containing: - - loss: action prediction loss - """ - # Set frozen modules to eval - self.set_frozen_modules_to_eval_mode() - - backbone_output = self.process_backbone_output(backbone_output) - - # Get vision and language embeddings. - vl_embeds = backbone_output.backbone_features - device = vl_embeds.device - - # Get embodiment ID. - embodiment_id = action_input.embodiment_id - - # Handle state history - assert action_input.state.shape[1] == self.config.state_history_length - action_input.state = action_input.state.view(action_input.state.shape[0], 1, -1) - - # Embed state. - state_features = self.state_encoder(action_input.state, embodiment_id) - - # Dropout state features (training only): zero out dropped states. - if self.training and self.state_dropout_prob > 0: - do_dropout = torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob - do_dropout = do_dropout[:, None, None].to(dtype=state_features.dtype) - state_features = state_features * (1 - do_dropout) - - # Embed noised action trajectory. 
- actions = action_input.action - noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype) - t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype) - t = t[:, None, None] # shape (B,1,1) for broadcast - - noisy_trajectory = (1 - t) * noise + t * actions - velocity = actions - noise - - # Convert (continuous) t -> discrete if needed - t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long() - action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id) - - # Maybe add position embedding. - if self.config.add_pos_embed: - pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device) - pos_embs = self.position_embedding(pos_ids).unsqueeze(0) - action_features = action_features + pos_embs - - # Join vision, language, state and action embedding along sequence dimension. - sa_embs = torch.cat((state_features, action_features), dim=1) - vl_attn_mask = backbone_output.backbone_attention_mask - - if self.config.use_alternate_vl_dit: - image_mask = backbone_output.image_mask - backbone_attention_mask = backbone_output.backbone_attention_mask - model_output, _ = self.model( - hidden_states=sa_embs, - encoder_hidden_states=vl_embeds, - encoder_attention_mask=vl_attn_mask, - timestep=t_discretized, - return_all_hidden_states=True, - image_mask=image_mask, - backbone_attention_mask=backbone_attention_mask, - ) - else: - model_output, _ = self.model( - hidden_states=sa_embs, - encoder_hidden_states=vl_embeds, - encoder_attention_mask=vl_attn_mask, - timestep=t_discretized, - return_all_hidden_states=True, - ) - - pred = self.action_decoder(model_output, embodiment_id) - pred_actions = pred[:, -actions.shape[1] :] - - # Slice out only the action portion of pred and target. 
- action_mask = action_input.action_mask - action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask - loss = action_loss.sum() / (action_mask.sum() + 1e-6) - - return { - "loss": loss, - "action_loss": action_loss, - "action_mask": action_mask, - "backbone_features": vl_embeds, - "state_features": state_features, - } - def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature: """ Encode features for the action head. @@ -523,7 +361,6 @@ def __init__( backbone_embedding_dim: int, load_bf16: bool, tune_top_llm_layers: int, - trainable_params_fp32: bool, transformers_loading_kwargs: dict[str, Any] | None = None, ): super().__init__() @@ -558,11 +395,6 @@ def __init__( self.model.model.language_model.layers.pop(-1) self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers) - if load_bf16 and trainable_params_fp32: - for name, param in self.named_parameters(): - if param.requires_grad: - param.data = param.data.to(torch.float32) - logger.debug("Casting trainable parameter %s to fp32", name) def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int) -> None: self.tune_llm = tune_llm @@ -692,7 +524,6 @@ def __init__( backbone_embedding_dim=config.backbone_embedding_dim, load_bf16=config.load_bf16, tune_top_llm_layers=config.tune_top_llm_layers, - trainable_params_fp32=config.backbone_trainable_params_fp32, transformers_loading_kwargs=transformers_loading_kwargs, ) diff --git a/vllm_omni/diffusion/models/gr00t/modeling/image_augmentations.py b/vllm_omni/diffusion/models/gr00t/modeling/image_augmentations.py deleted file mode 100755 index f4908564c00..00000000000 --- a/vllm_omni/diffusion/models/gr00t/modeling/image_augmentations.py +++ /dev/null @@ -1,564 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import warnings -from collections.abc import Sequence - -import albumentations as A -import cv2 -import numpy as np -import torch -import torchvision.transforms.v2 as transforms - - -def apply_with_replay(transform, images, masks=None, replay=None): - """ - Apply albumentations transforms to multiple images with replay functionality. - When masks are provided, mask-based transforms run per-frame before the main transform. - - Args: - transform: Albumentations ReplayCompose or Compose transform - images: List of PIL Images to transform - masks: Optional list of masks aligned with images (H, W) - replay: Optional replay data for consistent transforms. If None, creates new replay. 
- - Returns: - tuple: (transformed_tensors_list, replay_data) - - transformed_tensors_list: List of transformed torch tensors (C, H, W) as uint8 - - replay_data: Replay data for consistent transforms across images (None for regular Compose) - """ - transformed_tensors = [] - current_replay = replay - - # Check if transform supports replay (ReplayCompose) - has_replay = hasattr(transform, "replay") - - # Get mask-based transforms (applied per-frame, not replayed) - mask_transforms = getattr(transform, "mask_transforms", None) - - if masks is not None and len(masks) != len(images): - raise ValueError(f"Number of masks ({len(masks)}) must match number of images ({len(images)})") - - for idx, img in enumerate(images): - img_array = np.array(img) - mask_array = None if masks is None else np.array(masks[idx]) - if mask_array is not None and mask_array.dtype == np.bool_: - mask_array = mask_array.astype(np.uint8) - - # Apply mask-based transforms FIRST (per-frame, using current frame's mask) - if mask_transforms and mask_array is not None: - for mask_tf in mask_transforms: - result = mask_tf(image=img_array, mask=mask_array) - img_array = result["image"] - - if has_replay: - if current_replay is None: - # First image - create replay data - augmented_image = transform(image=img_array) - current_replay = augmented_image["replay"] - else: - # Subsequent images - use replay for consistent transforms - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - augmented_image = transform.replay(image=img_array, saved_augmentations=current_replay) - else: - # Regular Compose transform - no replay functionality - augmented_image = transform(image=img_array) - - img_array = augmented_image["image"] - # Convert to uint8 if needed (albumentations may return float32 in [0,1]) - if img_array.dtype == np.float32: - img_array = (img_array * 255).astype(np.uint8) - elif img_array.dtype != np.uint8: - raise ValueError(f"Unexpected data type: 
{img_array.dtype}") - - # Convert to torch tensor (C, H, W) as uint8 - img_tensor = torch.from_numpy(img_array).permute(2, 0, 1) - transformed_tensors.append(img_tensor) - - return transformed_tensors, current_replay - - -class MaskedColorTransform(A.ImageOnlyTransform): - """Apply random tint to specific mask regions. - - Args: - target_mask_values: List of mask values to apply the transform to - alpha_range: (min, max) for random_tint overlay intensity - p: Probability of applying the transform - """ - - def __init__( - self, - target_mask_values: Sequence[int], - alpha_range: tuple[float, float] = (0.3, 1.0), - p: float = 0.5, - always_apply: bool | None = None, - ): - super().__init__(p=p, always_apply=always_apply) - self.target_mask_values = list(target_mask_values) - self.alpha_range = alpha_range - - def apply(self, img: np.ndarray, mask: np.ndarray = None, **params) -> np.ndarray: - if mask is None: - return img - - region_mask = np.zeros(mask.shape[:2], dtype=bool) - for val in self.target_mask_values: - region_mask |= mask == val - - if not region_mask.any(): - return img - - # Random color - random_color = np.random.randint(0, 256, size=3).astype(np.float32) - result = img.copy().astype(np.float32) - - # Random tint: semi-transparent overlay - alpha = np.random.uniform(self.alpha_range[0], self.alpha_range[1]) - for c in range(3): - result[region_mask, c] = result[region_mask, c] * (1 - alpha) + random_color[c] * alpha - - return np.clip(result, 0, 255).astype(np.uint8) - - def get_params_dependent_on_data(self, params, data) -> dict: - return {"mask": data.get("mask")} - - def get_transform_init_args_names(self) -> tuple[str, ...]: - return ("target_mask_values", "alpha_range") - - -class BackgroundNoiseTransform(A.ImageOnlyTransform): - """Replace specified mask regions with random noise. - - This transform replaces pixels where mask value matches target_mask_values with random RGB noise, - useful for domain randomization in sim-to-real transfer. 
- - Args: - p: Probability of applying the transform - target_mask_values: Mask values to replace with noise (default: [0]) - """ - - def __init__( - self, - p: float = 1.0, - target_mask_values: Sequence[int] | None = None, - always_apply: bool | None = None, - ): - super().__init__(p=p, always_apply=always_apply) - self.target_mask_values = [0] if target_mask_values is None else list(target_mask_values) - - def apply(self, img: np.ndarray, mask: np.ndarray = None, **params) -> np.ndarray: - if mask is None: - return img - - result = img.copy() - mask_2d = mask[..., 0] if mask.ndim == 3 else mask - background = np.isin(mask_2d, self.target_mask_values) - - if background.any(): - noise = np.random.randint(0, 256, size=result.shape, dtype=np.uint8) - result[background] = noise[background] - - return result - - def get_params_dependent_on_data(self, params, data) -> dict: - return {"mask": data.get("mask")} - - def get_transform_init_args_names(self) -> tuple[str, ...]: - return ("target_mask_values",) - - -class FractionalRandomCrop(A.DualTransform): - """Crop a random part of the input based on fractions while maintaining aspect ratio. - - Args: - crop_fraction: Fraction of the image to crop (0.0 to 1.0). The crop will maintain - the original aspect ratio and be this fraction of the original area. - p: probability of applying the transform. 
Default: 1.0 - - Targets: - image, mask, bboxes, keypoints - - Image types: - uint8, float32 - """ - - def __init__( - self, - crop_fraction: float = 0.9, - p: float = 1.0, - always_apply: bool | None = None, - ): - super().__init__(p=p, always_apply=always_apply) - if not 0.0 < crop_fraction <= 1.0: - raise ValueError("crop_fraction must be between 0.0 and 1.0") - self.crop_fraction = crop_fraction - - def apply(self, img: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray: - x_min, y_min, x_max, y_max = crop_coords - return img[y_min:y_max, x_min:x_max] - - def apply_to_bboxes(self, bboxes: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray: - return A.augmentations.crops.functional.crop_bboxes_by_coords(bboxes, crop_coords, params["shape"]) - - def apply_to_keypoints(self, keypoints: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray: - return A.augmentations.crops.functional.crop_keypoints_by_coords(keypoints, crop_coords) - - def get_params_dependent_on_data(self, params, data) -> dict[str, tuple[int, int, int, int]]: - image_shape = params["shape"][:2] - height, width = image_shape - - # Calculate crop dimensions with linear scaling - crop_height = int(height * self.crop_fraction) - crop_width = int(width * self.crop_fraction) - - # Ensure minimum size of 1x1 - crop_height = max(1, crop_height) - crop_width = max(1, crop_width) - # Random position for crop - max_y = height - crop_height - max_x = width - crop_width - - y_min = np.random.randint(0, max_y + 1) if max_y > 0 else 0 - x_min = np.random.randint(0, max_x + 1) if max_x > 0 else 0 - - crop_coords = (x_min, y_min, x_min + crop_width, y_min + crop_height) - return {"crop_coords": crop_coords} - - def get_transform_init_args_names(self) -> tuple[str, ...]: - return ("crop_fraction",) - - -class FractionalCenterCrop(A.DualTransform): - """Crop the center part of the input based on fractions while maintaining aspect ratio. 
- - Args: - crop_fraction: Fraction of the image to crop (0.0 to 1.0). The crop will maintain - the original aspect ratio and be this fraction of the original area. - p: probability of applying the transform. Default: 1.0 - - Targets: - image, mask, bboxes, keypoints - - Image types: - uint8, float32 - """ - - def __init__( - self, - crop_fraction: float = 0.9, - p: float = 1.0, - always_apply: bool | None = None, - ): - super().__init__(p=p, always_apply=always_apply) - if not 0.0 < crop_fraction <= 1.0: - raise ValueError("crop_fraction must be between 0.0 and 1.0") - self.crop_fraction = crop_fraction - - def apply(self, img: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray: - x_min, y_min, x_max, y_max = crop_coords - return img[y_min:y_max, x_min:x_max] - - def apply_to_bboxes(self, bboxes: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray: - return A.augmentations.crops.functional.crop_bboxes_by_coords(bboxes, crop_coords, params["shape"]) - - def apply_to_keypoints(self, keypoints: np.ndarray, crop_coords: tuple[int, int, int, int], **params) -> np.ndarray: - return A.augmentations.crops.functional.crop_keypoints_by_coords(keypoints, crop_coords) - - def get_params_dependent_on_data(self, params, data) -> dict[str, tuple[int, int, int, int]]: - image_shape = params["shape"][:2] - height, width = image_shape - - # Calculate crop dimensions with linear scaling - crop_height = int(height * self.crop_fraction) - crop_width = int(width * self.crop_fraction) - - # Ensure minimum size of 1x1 - crop_height = max(1, crop_height) - crop_width = max(1, crop_width) - - # Center the crop - y_min = (height - crop_height) // 2 - x_min = (width - crop_width) // 2 - - crop_coords = (x_min, y_min, x_min + crop_width, y_min + crop_height) - return {"crop_coords": crop_coords} - - def get_transform_init_args_names(self) -> tuple[str, ...]: - return ("crop_fraction",) - - -class LetterBoxPad(A.DualTransform): - """Pad 
non-square images to square by adding black bars (letterboxing). - - This is the albumentations equivalent of LetterBoxTransform (torchvision). - Ensures all images have the same spatial dimensions after padding, - regardless of their original aspect ratio. - - Targets: - image - - Image types: - uint8, float32 - """ - - def __init__(self, p: float = 1.0, always_apply: bool | None = None): - super().__init__(p=p, always_apply=always_apply) - - def apply( - self, - img: np.ndarray, - pad_top: int = 0, - pad_bottom: int = 0, - pad_left: int = 0, - pad_right: int = 0, - **params, - ) -> np.ndarray: - if pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0: - return img - return cv2.copyMakeBorder(img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0) - - def get_params_dependent_on_data(self, params, data) -> dict[str, int]: - h, w = params["shape"][:2] - if h == w: - return {"pad_top": 0, "pad_bottom": 0, "pad_left": 0, "pad_right": 0} - max_dim = max(h, w) - pad_h = max_dim - h - pad_w = max_dim - w - return { - "pad_top": pad_h // 2, - "pad_bottom": pad_h - pad_h // 2, - "pad_left": pad_w // 2, - "pad_right": pad_w - pad_w // 2, - } - - def get_transform_init_args_names(self) -> tuple[str, ...]: - return () - - -def build_image_transformations_albumentations( - image_target_size, - image_crop_size, - random_rotation_angle, - color_jitter_params, - shortest_image_edge, - crop_fraction, - extra_augmentation_config: dict | None = None, -): - """ - Build albumentations-based image transformations equivalent to the torchvision version. 
- - Args: - image_target_size: Target size for resizing (list of [height, width]) - image_crop_size: Size for cropping (list of [height, width]) - random_rotation_angle: Maximum rotation angle in degrees (0 for no rotation) - color_jitter_params: Dictionary with color jitter parameters (brightness, contrast, saturation, hue) - shortest_image_edge: Shortest edge size for resizing - crop_fraction: Fraction of image to crop - extra_augmentation_config: Optional dict for additional augmentations. Supported keys: - - "background_noise_transforms": list of dicts, each with: - - "target_mask_values": list of int (e.g., [0]) - - "p": float (probability of applying transform) - - "masked_region_transforms": list of dicts, each with: - - "target_mask_values": list of int (e.g., [4] or [5]) - - "p": float (probability of applying transform) - - "alpha_range": [min, max] for random_tint mode intensity - - Returns: - tuple: (train_transform, eval_transform) - raw albumentations transforms - """ - - if crop_fraction is None: - fraction_to_use = image_crop_size[0] / image_target_size[0] - else: - fraction_to_use = crop_fraction - - if shortest_image_edge is None: - max_size = image_target_size[0] - else: - max_size = shortest_image_edge - - extra_augmentation_config = extra_augmentation_config or {} - - # Training transforms (using ReplayCompose for consistent augmentation across views) - # Use SmallestMaxSize to preserve aspect ratios, with INTER_AREA for antialiasing - train_transform_list = [ - LetterBoxPad(), - A.SmallestMaxSize(max_size=max_size, interpolation=cv2.INTER_AREA), - FractionalRandomCrop(crop_fraction=fraction_to_use), - A.SmallestMaxSize(max_size=max_size, interpolation=cv2.INTER_AREA), - ] - - if random_rotation_angle is not None and random_rotation_angle != 0: - train_transform_list.append(A.Rotate(limit=random_rotation_angle, p=1.0)) - - if color_jitter_params is not None: - # Map torchvision ColorJitter parameters to albumentations ColorJitter - # Note: 
albumentations uses different parameter names and ranges - train_transform_list.append( - A.ColorJitter( - brightness=color_jitter_params.get("brightness", 0.0), - contrast=color_jitter_params.get("contrast", 0.0), - saturation=color_jitter_params.get("saturation", 0.0), - hue=color_jitter_params.get("hue", 0.0), - p=1.0, - ) - ) - - train_transform = A.ReplayCompose(train_transform_list, p=1.0) - - # === Mask-based augmentations (applied per-frame, NOT in ReplayCompose) === - # These transforms depend on per-frame mask data and must not be replayed - # to ensure each frame uses its own mask - mask_transforms = [] - - # Background noise on mask regions - for noise_cfg in extra_augmentation_config.get("background_noise_transforms", []): - target_mask_values = noise_cfg.get("target_mask_values", [0]) - p = noise_cfg.get("p", 1.0) - mask_transforms.append( - BackgroundNoiseTransform( - p=float(p), - target_mask_values=target_mask_values, - ) - ) - - # Masked region transforms - for transform_cfg in extra_augmentation_config.get("masked_region_transforms", []): - target_mask_values = transform_cfg.get("target_mask_values", []) - p = transform_cfg.get("p", 0.5) - alpha_range = tuple(transform_cfg.get("alpha_range", [0.3, 1.0])) - - mask_transforms.append( - MaskedColorTransform( - target_mask_values=target_mask_values, - alpha_range=alpha_range, - p=p, - ) - ) - - # Attach mask transforms to the main transform for use in apply_with_replay - train_transform.mask_transforms = mask_transforms if mask_transforms else None - - # Evaluation transforms (deterministic, no extra augmentations) - # Use SmallestMaxSize to preserve aspect ratios, with INTER_AREA for antialiasing - eval_transform = A.Compose( - [ - LetterBoxPad(), - A.SmallestMaxSize(max_size=max_size, interpolation=cv2.INTER_AREA), - FractionalCenterCrop(crop_fraction=fraction_to_use), - A.SmallestMaxSize(max_size=max_size, interpolation=cv2.INTER_AREA), - ] - ) - - return train_transform, eval_transform - - -class 
LetterBoxTransform: - """Custom transform to pad non-square images to square by adding black bars. - - Works with any tensor shape where the last 3 dimensions are (C, H, W). - Leading dimensions (batch, time, views, etc.) are preserved. - """ - - def __call__(self, img: torch.Tensor) -> torch.Tensor: - """ - Pad image to square dimensions by adding black bars to the smaller dimension. - - Args: - img: Image tensor of shape (..., C, H, W) where ... can be any leading dimensions - Examples: (C, H, W), (B, C, H, W), (B, T*V, C, H, W) - - Returns: - Padded image tensor of shape (..., C, max(H,W), max(H,W)) - """ - # Get the height and width from the last 2 dimensions - *leading_dims, c, h, w = img.shape - - if h == w: - return img - - # Calculate padding needed - max_dim = max(h, w) - pad_h = max_dim - h - pad_w = max_dim - w - - # Add padding to center the image (divide padding equally on both sides) - pad_top = pad_h // 2 - pad_bottom = pad_h - pad_top - pad_left = pad_w // 2 - pad_right = pad_w - pad_left - - # If we have leading dimensions, we need to flatten them, pad, then unflatten - if leading_dims: - # Reshape to (batch, C, H, W) where batch includes all leading dimensions - batch_size = math.prod(leading_dims) - img_reshaped = img.reshape(batch_size, c, h, w) - - # Apply padding to each image in the batch - # torchvision padding format: (left, right, top, bottom) - padded_img = transforms.functional.pad( - img_reshaped, padding=[pad_left, pad_top, pad_right, pad_bottom], fill=0 - ) - - # Reshape back to original leading dimensions - output_shape = leading_dims + [c, max_dim, max_dim] - padded_img = padded_img.reshape(output_shape) - else: - # Simple case: just (C, H, W) - padded_img = transforms.functional.pad(img, padding=[pad_left, pad_top, pad_right, pad_bottom], fill=0) - - return padded_img - - -def build_image_transformations(image_target_size, image_crop_size, random_rotation_angle, color_jitter_params): - """ - Build torchvision-based image 
transformations. - - Args: - image_target_size: Target size for resizing (list of [height, width]) - image_crop_size: Size for cropping (list of [height, width]) - random_rotation_angle: Maximum rotation angle in degrees (0 for no rotation) - color_jitter_params: Dictionary with color jitter parameters (brightness, contrast, saturation, hue) - - Returns: - tuple: (train_transform, eval_transform) - torchvision transforms - """ - transform_list = [ - transforms.ToImage(), - LetterBoxTransform(), - # transforms.ToDtype(torch.get_default_dtype(), scale=True), - transforms.Resize(size=image_target_size), - transforms.RandomCrop(size=image_crop_size), - transforms.Resize(size=image_target_size), - ] - if random_rotation_angle is not None and random_rotation_angle != 0: - transform_list.append(transforms.RandomRotation(degrees=[-random_rotation_angle, random_rotation_angle])) - if color_jitter_params is not None: - transform_list.append(transforms.ColorJitter(**color_jitter_params)) - train_image_transform = transforms.Compose(transform_list) - eval_image_transform = transforms.Compose( - [ - transforms.ToImage(), - # transforms.ToDtype(torch.get_default_dtype(), scale=True), - LetterBoxTransform(), - transforms.Resize(size=image_target_size), - transforms.CenterCrop(size=image_crop_size), - transforms.Resize(size=image_target_size), - ] - ) - return train_image_transform, eval_image_transform diff --git a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py index b53fd3dd520..e67225fc4a2 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py @@ -15,14 +15,12 @@ import json import os -import random import re import warnings from copy import deepcopy from pathlib import Path from typing import Any -import albumentations as A import numpy as np import torch import torchvision.transforms.v2 as 
transforms @@ -38,11 +36,43 @@ from vllm_omni.diffusion.models.gr00t.dataio.state_action.state_action_processor import StateActionProcessor from vllm_omni.diffusion.models.gr00t.dataio.utils import parse_modality_configs, to_json_serializable -from .image_augmentations import ( - apply_with_replay, - build_image_transformations, - build_image_transformations_albumentations, -) + +class LetterBoxTransform: + """Pad image to square dimensions by adding black bars to the smaller side.""" + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + import math + + *leading_dims, c, h, w = img.shape + if h == w: + return img + max_dim = max(h, w) + pad_h, pad_w = max_dim - h, max_dim - w + pad_top, pad_left = pad_h // 2, pad_w // 2 + pad_bottom, pad_right = pad_h - pad_top, pad_w - pad_left + if leading_dims: + batch_size = math.prod(leading_dims) + img_r = img.reshape(batch_size, c, h, w) + padded = transforms.functional.pad(img_r, padding=[pad_left, pad_top, pad_right, pad_bottom], fill=0) + return padded.reshape(leading_dims + [c, max_dim, max_dim]) + return transforms.functional.pad(img, padding=[pad_left, pad_top, pad_right, pad_bottom], fill=0) + + +def _build_eval_image_transform( + image_target_size: list[int], + image_crop_size: list[int], +) -> transforms.Compose: + """Deterministic eval/inference image transform (letterbox → resize → centercrop → resize).""" + return transforms.Compose( + [ + transforms.ToImage(), + LetterBoxTransform(), + transforms.Resize(size=image_target_size), + transforms.CenterCrop(size=image_crop_size), + transforms.Resize(size=image_target_size), + ] + ) + try: from transformers import Qwen3VLProcessor @@ -54,9 +84,8 @@ # Suppress protobuf deprecation warnings warnings.filterwarnings("ignore", category=DeprecationWarning, module="google.protobuf") -### Mapping from embodiment tag to projector index. 
EMBODIMENT_TAG_TO_PROJECTOR_INDEX = { - ##### Pretrain embodiment ids (in base model) ##### + # Pretrain embodiment ids "oxe_droid_relative_eef_relative_joint": 24, "xdof_relative_eef_relative_joint": 27, "xdof_relative_eef_relative_joint_subtask": 27, @@ -65,7 +94,7 @@ "real_r1_pro_sharpa_relative_eef_human": 26, "real_r1_pro_sharpa_relative_eef_maxinsights": 26, "real_r1_pro_sharpa_relative_eef_mecka": 26, - ##### Posttrain embodiment ids ##### + # Posttrain embodiment ids "unitree_g1_full_body_with_waist_height_nav_cmd": 25, "unitree_g1_sonic": 11, "simpler_env_google": 0, @@ -83,10 +112,7 @@ def build_processor(model_name: str, transformers_loading_kwargs: dict) -> Qwen3 "Qwen3VLProcessor is not available. Please upgrade transformers: pip install transformers>=4.52.0" ) if model_name == "nvidia/Cosmos-Reason2-2B": - # Cosmos-Reason2-2B is a Qwen3-VL 2B backbone but the NVIDIA repo does - # not publish a Qwen3VLProcessor-compatible preprocessor. Fall back to - # the upstream Qwen3-VL repo for processor artifacts only; model - # weights are still loaded from `nvidia/Cosmos-Reason2-2B`. + # Cosmos-Reason2-2B lacks a Qwen3VLProcessor; fall back to upstream artifacts. logger.warning_once( "Substituting processor from %s because %s does not ship one. 
" "If you fine-tune Cosmos-Reason2-2B's tokenizer/image processor, " @@ -105,7 +131,6 @@ def __init__( model_type: str = "qwen", transformers_loading_kwargs: dict = {}, ): - ### We need to use the same processor for padding input ids and concat self.processor = build_processor(model_name, transformers_loading_kwargs) # Set padding side to 'left' for Flash Attention compatibility self.processor.tokenizer.padding_side = "left" @@ -119,7 +144,6 @@ def __call__(self, features: list[dict[str, Any]]) -> BatchFeature: for key in keys: values = [elem[key] for elem in features if key in elem] if key == "vlm_content": - # Handle vlm_content specially - extract text and images text_list = [] image_inputs = [] for v in values: @@ -145,7 +169,6 @@ def __call__(self, features: list[dict[str, Any]]) -> BatchFeature: ): raise Exception("Not implemented") else: - # state, state_mask, action and action_mask - stack to form batch dimension batch[key] = torch.from_numpy(np.stack(values)) return BatchFeature(data={"inputs": batch}) @@ -175,19 +198,16 @@ def __init__( max_action_dim: int = 29, max_action_horizon: int = 50, apply_sincos_state_encoding: bool = False, - use_albumentations: bool = False, - extra_augmentation_config: dict | None = None, use_relative_action: bool = False, embodiment_id_mapping: dict[str, int] | None = None, transformers_loading_kwargs: dict = {"trust_remote_code": True}, - # State augmentation exclude_state: bool = False, - state_dropout_prob: float = 0.0, # Normalization use_mean_std: bool = False, - # Backward-compat params (stored but not actively used) - letter_box_transform: bool = False, + **kwargs, # absorb deprecated training-only keys from saved processor_config.json ): + if kwargs: + logger.debug("Gr00tN1d7Processor: ignoring unknown keys: %s", list(kwargs)) self.modality_configs = parse_modality_configs(modality_configs) # Initialize StateActionProcessor for state/action normalization @@ -200,21 +220,14 @@ def __init__( 
use_relative_action=use_relative_action, ) - # Save state action processor settings self.use_percentiles = use_percentiles self.use_mean_std = use_mean_std self.clip_outliers = clip_outliers self.apply_sincos_state_encoding = apply_sincos_state_encoding self.use_relative_action = use_relative_action - self.extra_augmentation_config = extra_augmentation_config - # State augmentation settings self.exclude_state = exclude_state - self.state_dropout_prob = state_dropout_prob - - self.letter_box_transform = letter_box_transform - # Save VLM settings self.formalize_language = formalize_language self.model_name = model_name self.model_type = model_type @@ -223,7 +236,6 @@ def __init__( self.max_action_dim = max_action_dim self.max_action_horizon = max_action_horizon - # Save image processing settings self.image_crop_size = image_crop_size self.image_target_size = image_target_size self.random_rotation_angle = random_rotation_angle @@ -232,54 +244,29 @@ def __init__( # Set padding side to 'left' for Flash Attention compatibility self.processor.tokenizer.padding_side = "left" self.embodiment_id_mapping = embodiment_id_mapping or EMBODIMENT_TAG_TO_PROJECTOR_INDEX - # Merge any missing pre-trained embodiment tags into the custom mapping for k, v in EMBODIMENT_TAG_TO_PROJECTOR_INDEX.items(): if k not in self.embodiment_id_mapping: self.embodiment_id_mapping[k] = v self.shortest_image_edge = shortest_image_edge self.crop_fraction = crop_fraction - # Statistics cache (mirrors state_action_processor.statistics for serialization) self.statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]] = {} - # Choose between torchvision and albumentations transforms - self.use_albumentations = use_albumentations - if use_albumentations: - self.train_image_transform, self.eval_image_transform = build_image_transformations_albumentations( - image_target_size, - image_crop_size, - random_rotation_angle, - color_jitter_params, - shortest_image_edge, - crop_fraction, - 
extra_augmentation_config=self.extra_augmentation_config, - ) - else: - self.train_image_transform, self.eval_image_transform = build_image_transformations( - image_target_size, - image_crop_size, - random_rotation_angle, - color_jitter_params, - ) + # Eval/inference image transform + self.eval_image_transform = _build_eval_image_transform( + image_target_size, + image_crop_size, + ) self._collator = self.data_collator_class( model_name=model_name, model_type=model_type, transformers_loading_kwargs=transformers_loading_kwargs, ) - self.train() @property def collator(self): return self._collator - def train(self): - super().train() - self.state_action_processor.train() - - def eval(self): - super().eval() - self.state_action_processor.eval() - def set_statistics( self, statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]], @@ -394,24 +381,14 @@ def process_observation(self, observation: dict[str, Any], embodiment_tag: Embod normalized_states = torch.cat([normalized_states, torch.zeros(padding_shape)], dim=-1) transformed_observation["state"] = normalized_states - # Process images: observation values are (B, T, H, W, C) numpy arrays image_keys = modality_config["video"].modality_keys images_dict = {view: torch.from_numpy(observation[f"video.{view}"]) for view in image_keys} images = torch.stack([images_dict[view] for view in image_keys], dim=2) # (B, T, V, H, W, C) assert images.ndim == 6 B, T, V, img_H, img_W, img_C = images.shape - if self.use_albumentations: - images_flat = images.reshape(B * T * V, img_H, img_W, img_C) - pil_images = [Image.fromarray(img.numpy()) for img in images_flat] - transformed_pil, _ = apply_with_replay(self.eval_image_transform, pil_images) - transformed_stacked = torch.stack(transformed_pil) # (B*T*V, C, H_new, W_new) - _, img_C_new, img_H_new, img_W_new = transformed_stacked.shape - transformed_images = transformed_stacked.reshape(B, T * V, img_C_new, img_H_new, img_W_new).numpy() - else: - # Rearrange (B, T, V, H, W, C) 
to (B, T*V, C, H, W) for torchvision. - images_perm = images.permute(0, 1, 2, 5, 3, 4).reshape(B, T * V, img_C, img_H, img_W) - transformed_images = self.eval_image_transform(images_perm).numpy() + images_perm = images.permute(0, 1, 2, 5, 3, 4).reshape(B, T * V, img_C, img_H, img_W) # (B,T*V,C,H,W) + transformed_images = self.eval_image_transform(images_perm).numpy() language_key = modality_config["language"].modality_keys[0] language = [ @@ -432,7 +409,6 @@ def process_observation(self, observation: dict[str, Any], embodiment_tag: Embod embodiment_id = torch.ones(B, dtype=torch.int32) * self.embodiment_id_mapping[embodiment_tag.value] transformed_observation["embodiment_id"] = embodiment_id - # Action mask: shape (B, max_action_horizon), 1 in the valid horizon window action_config = modality_config["action"] action_horizon = len(action_config.delta_indices) assert action_horizon <= self.max_action_horizon, ( @@ -448,16 +424,7 @@ def process_observation(self, observation: dict[str, Any], embodiment_tag: Embod return BatchFeature(transformed_observation) def _apply_vlm_processing(self, images: np.ndarray, language: str) -> BatchFeature: - """ - Args: - batch: - video: [T, C, H, W] - Returns: vlm_content format for collation - """ - # Convert images to PIL format pil_images = [Image.fromarray(np.transpose(v, (1, 2, 0))) for v in images] - - # Create conversation with images and text conversation = [ { "role": "user", @@ -468,10 +435,7 @@ def _apply_vlm_processing(self, images: np.ndarray, language: str) -> BatchFeatu } ] - # Apply chat template but don't process yet - let collator handle it text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) - - # Return vlm_content format for collation return { "vlm_content": { "text": text, @@ -490,7 +454,6 @@ def __call__( action_data = content.actions state_data = content.states - # Use StateActionProcessor to handle relative conversion and normalization norm_state_dict, 
normalized_actions = self.state_action_processor.apply( state=state_data, action=action_data, @@ -498,14 +461,12 @@ def __call__( ) if normalized_actions: - # Concatenate actions action_keys = self.modality_configs[embodiment_tag.value]["action"].modality_keys normalized_actions = torch.cat( [torch.from_numpy(normalized_actions[key]) for key in action_keys], dim=-1, ) # (t, d) action_dim = normalized_actions.shape[1] - # Pad action to max_action_dim normalized_actions = torch.cat( [ normalized_actions, @@ -516,7 +477,6 @@ def __call__( ], dim=-1, ) # (t, max_action_dim) - # Pad action to max_action_horizon action_horizon = normalized_actions.shape[0] assert action_horizon <= self.max_action_horizon, ( f"Action sequence length {action_horizon} exceeds max_action_horizon" @@ -533,23 +493,18 @@ def __call__( ], dim=0, ) # (max_action_horizon, max_action_dim) - # Create action mask action_mask = torch.ones_like(normalized_actions) action_mask[action_horizon:] = 0 action_mask[:, action_dim:] = 0 else: - assert not self.training, "Action is required in training mode" normalized_actions = None action_mask = None - # Concatenate states with optional dropout/noise augmentation state_keys = self.modality_configs[embodiment_tag.value]["state"].modality_keys exclude_state = self.exclude_state or getattr( self.modality_configs[embodiment_tag.value]["state"], "exclude_state", False ) - if exclude_state or ( - self.state_dropout_prob > 0 and random.random() < self.state_dropout_prob and self.training - ): + if exclude_state: normalized_states = torch.cat( [torch.from_numpy(np.zeros_like(state_data[key])) for key in state_keys], dim=-1 ) @@ -566,11 +521,7 @@ def __call__( dim=-1, ) - # Crop and resize images. 
- if self.training: - image_transform = self.train_image_transform - else: - image_transform = self.eval_image_transform + image_transform = self.eval_image_transform image_keys = self.modality_configs[embodiment_tag.value]["video"].modality_keys if self.formalize_language: @@ -604,33 +555,16 @@ def _get_vlm_inputs( image_keys: list[str], images: list[Image.Image], masks: dict[str, list[np.ndarray]] | None, - image_transform: transforms.Compose | A.Compose, + image_transform: transforms.Compose, language: str, ): temporal_stacked_images = {} - if self.use_albumentations: - # Use albumentations transforms - replay = None - for view in image_keys: - assert view in images, f"{view} not in {images}" - if masks is not None: - assert view in masks, f"{view} not in masks" - view_masks = masks.get(view) if masks else None - view_images = images[view] - - # Apply transforms with replay for consistency - transformed_images, replay = apply_with_replay(image_transform, view_images, view_masks, replay) - temporal_stacked_images[view] = torch.stack(transformed_images) # (T, C, H, W) - else: - if masks is not None: - raise ValueError("Mask transforms require albumentations. 
Set use_albumentations_transforms=True.") - # Use torchvision transforms - for view in image_keys: - assert view in images, f"{view} not in {images}" - temporal_stacked_images[view] = torch.stack( - [image_transform(img) for img in images[view]] - ) # (T, C, H, W) + if masks is not None: + raise ValueError("Mask-based transforms are not supported at inference.") + for view in image_keys: + assert view in images, f"{view} not in {images}" + temporal_stacked_images[view] = torch.stack([image_transform(img) for img in images[view]]) # (T, C, H, W) for k, v in temporal_stacked_images.items(): assert isinstance(k, str), f"{k} is not a string" @@ -657,44 +591,34 @@ def save_pretrained(self, save_directory: str | Path) -> list[Path]: "processor_class": self.__class__.__name__, "processor_kwargs": { "modality_configs": to_json_serializable(self.modality_configs), - # Image processing settings "image_crop_size": self.image_crop_size, "image_target_size": self.image_target_size, - "use_albumentations": self.use_albumentations, "random_rotation_angle": self.random_rotation_angle, "color_jitter_params": self.color_jitter_params, "shortest_image_edge": self.shortest_image_edge, "crop_fraction": self.crop_fraction, - "letter_box_transform": self.letter_box_transform, - # VLM settings "model_name": self.model_name, "model_type": self.model_type, "formalize_language": self.formalize_language, - # State action dimensions "max_state_dim": self.max_state_dim, "max_action_dim": self.max_action_dim, "max_action_horizon": self.max_action_horizon, - # StateActionProcessor settings "use_percentiles": self.use_percentiles, "use_mean_std": self.use_mean_std, "clip_outliers": self.clip_outliers, "apply_sincos_state_encoding": self.apply_sincos_state_encoding, "use_relative_action": self.use_relative_action, - # State augmentation "exclude_state": self.exclude_state, - "state_dropout_prob": self.state_dropout_prob, }, } with open(main_config_file, "w") as f: json.dump(config, f, indent=2) - # 
Save statistics with open(statistics_file, "w") as f: json.dump( to_json_serializable(self.state_action_processor.statistics), f, indent=2, ) - # Save embodiment id mapping with open(embodiment_id_file, "w") as f: json.dump(self.embodiment_id_mapping, f, indent=2) return [main_config_file, statistics_file, embodiment_id_file] @@ -725,9 +649,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | Path, **kwargs): processor_kwargs["statistics"] = statistics processor_kwargs["embodiment_id_mapping"] = embodiment_id_mapping - # Backfill fields that older checkpoints may not have serialized. - # Without these, __init__ defaults silently apply - correct today but - # fragile if defaults ever change. + # Backfill missing fields from older checkpoints. processor_kwargs.setdefault("model_name", "nvidia/Cosmos-Reason2-2B") processor_kwargs.setdefault("model_type", "qwen") processor_kwargs.setdefault("clip_outliers", True) @@ -743,7 +665,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | Path, **kwargs): "color_jitter_params", "use_relative_action", "exclude_state", - "state_dropout_prob", "use_mean_std", "model_name", "model_type", diff --git a/vllm_omni/diffusion/models/gr00t/policy.py b/vllm_omni/diffusion/models/gr00t/policy.py index 4e81e7ce7b6..9ab3f63e7fe 100644 --- a/vllm_omni/diffusion/models/gr00t/policy.py +++ b/vllm_omni/diffusion/models/gr00t/policy.py @@ -99,7 +99,6 @@ def __init__( else model_dir ) self.processor: BaseProcessor = AutoProcessor.from_pretrained(processor_dir) - self.processor.eval() # Store embodiment-specific configurations self.embodiment_tag = embodiment_tag @@ -141,8 +140,7 @@ def __init__( self.collate_fn = self.processor.collator # Extract and validate language configuration - # Some embodiments (e.g. OXE_DROID) define multiple language keys for - # training-time augmentation (paraphrases). At inference we only use the first key. + # Embodiments may define multiple language keys (e.g. 
paraphrases); use only the first at inference. language_keys = self.modality_configs["language"].modality_keys language_delta_indices = self.modality_configs["language"].delta_indices assert len(language_keys) >= 1, "At least one language key is required" @@ -226,8 +224,6 @@ def check_observation(self, observation: dict[str, Any]) -> None: # Track batch size across modalities to ensure consistency bs = -1 - # ===== VIDEO VALIDATION ===== - # Validate each video stream defined in the modality config for video_key in self.modality_configs["video"].modality_keys: assert video_key in observation["video"], f"Video key '{video_key}' must be in observation" @@ -267,8 +263,6 @@ def check_observation(self, observation: dict[str, Any]) -> None: f"Video key '{video_key}'s channel 'C' must be 3. Got {batched_video.shape[-1]}" ) - # ===== STATE VALIDATION ===== - # Validate each state stream defined in the modality config for state_key in self.modality_configs["state"].modality_keys: # Check that the expected state key exists in the observation # Must happen before indexing; see video validation above. @@ -305,8 +299,6 @@ def check_observation(self, observation: dict[str, Any]) -> None: f"{len(self.modality_configs['state'].delta_indices)}. Got {batched_state.shape[1]}" ) - # ===== LANGUAGE VALIDATION ===== - # Validate each language stream defined in the modality config for language_key in self.modality_configs["language"].modality_keys: # Check that the expected language key exists in the observation # Must happen before indexing; see video validation above. 
@@ -512,8 +504,6 @@ def check_observation(self, observation: dict[str, Any]) -> None: """ modality_configs = self.get_modality_config() - # ===== VIDEO VALIDATION ===== - # Check video modalities with flat key format: 'video.camera_name' for video_key in modality_configs["video"].modality_keys: # Construct flat key expected in Gr00t sim environment parsed_key = f"video.{video_key}" @@ -547,8 +537,6 @@ def check_observation(self, observation: dict[str, Any]) -> None: f"Video key '{video_key}'s channel 'C' must be 3. Got {batched_video.shape[-1]}" ) - # ===== STATE VALIDATION ===== - # Check state modalities with flat key format: 'state.state_name' for state_key in modality_configs["state"].modality_keys: # Construct flat key expected in Gr00t sim environment parsed_key = f"state.{state_key}" @@ -577,14 +565,10 @@ def check_observation(self, observation: dict[str, Any]) -> None: f"{len(modality_configs['state'].delta_indices)}. Got {batched_state.shape[1]}" ) - # ===== LANGUAGE VALIDATION ===== - # Check language modalities (special handling for DC environment compatibility) for language_key in modality_configs["language"].modality_keys: - # PATCH: Legacy compatibility for DC environments # DC envs use 'annotation.human.coarse_action' instead of 'task' if language_key == "task" and "annotation.human.coarse_action" in observation: language_key = "annotation.human.coarse_action" - # /PATCH # Check that the expected language key exists assert language_key in observation, f"Language key '{language_key}' must be in observation" @@ -631,10 +615,9 @@ def _get_action( new_obs[modality] = {} for key in self.policy.modality_configs[modality].modality_keys: if modality == "language": - # PATCH: Legacy compatibility for DC environments + # DC envs use 'annotation.human.coarse_action' instead of 'task' if key == "task" and "annotation.human.coarse_action" in observation: parsed_key = "annotation.human.coarse_action" - # /PATCH else: parsed_key = key else: From 
f840a4e79cb378467f4572b89f708b7e76827fe9 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Mon, 1 Jun 2026 19:03:59 +0700 Subject: [PATCH 5/8] fix(gr00t): deterministic inference, vLLM FA, e2e precision test Signed-off-by: Zhengyuan Su --- .../examples/online_serving/gr00t.md | 5 +- examples/online_serving/gr00t/README.md | 116 ------- .../gr00t/molmospace_gr00t_eval_demo.py | 109 ------- .../online_serving/gr00t/openpi_client.py | 308 ------------------ examples/online_serving/gr00t/run_server.sh | 25 -- recipes/NVIDIA/GR00T-N1.7.md | 93 ++++++ tests/e2e/online_serving/test_gr00t_openpi.py | 123 ++++++- vllm_omni/diffusion/models/gr00t/__init__.py | 7 +- .../configs/embodiment/embodiment_configs.py | 6 - .../models/gr00t/configs/model/gr00t_n1d7.py | 1 + .../models/gr00t/dataio/collator/__init__.py | 6 - .../models/gr00t/dataio/collator/collators.py | 27 -- .../models/gr00t/dataio/interfaces.py | 105 ------ .../models/gr00t/modeling/gr00t_n1d7.py | 129 +++----- .../models/gr00t/modeling/modules/dit.py | 244 ++++---------- .../modules/embodiment_conditioned_mlp.py | 49 +-- .../modeling/modules/flowmatching_modules.py | 49 +-- .../gr00t/modeling/processing_gr00t_n1d7.py | 14 +- .../diffusion/models/gr00t/pipeline_gr00t.py | 48 +-- vllm_omni/diffusion/models/gr00t/policy.py | 275 +--------------- .../diffusion/models/gr00t/policy_base.py | 132 -------- .../models/internvla_a1/adapter_qwen3_vl.py | 84 ++++- vllm_omni/diffusion/registry.py | 1 - vllm_omni/diffusion/utils/flow_matching.py | 28 ++ 24 files changed, 448 insertions(+), 1536 deletions(-) delete mode 100644 examples/online_serving/gr00t/README.md delete mode 100644 examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py delete mode 100644 examples/online_serving/gr00t/openpi_client.py delete mode 100755 examples/online_serving/gr00t/run_server.sh create mode 100644 recipes/NVIDIA/GR00T-N1.7.md delete mode 100644 vllm_omni/diffusion/models/gr00t/dataio/collator/__init__.py delete mode 100755 
vllm_omni/diffusion/models/gr00t/dataio/collator/collators.py delete mode 100644 vllm_omni/diffusion/models/gr00t/dataio/interfaces.py mode change 100755 => 100644 vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py delete mode 100644 vllm_omni/diffusion/models/gr00t/policy_base.py create mode 100644 vllm_omni/diffusion/utils/flow_matching.py diff --git a/docs/user_guide/examples/online_serving/gr00t.md b/docs/user_guide/examples/online_serving/gr00t.md index b42de45b857..13e2183dc58 100644 --- a/docs/user_guide/examples/online_serving/gr00t.md +++ b/docs/user_guide/examples/online_serving/gr00t.md @@ -1,7 +1,5 @@ # GR00T OpenPI Serving -Source . - GR00T N1.7 is served through `/v1/realtime/robot/openpi`. The endpoint uses the OpenPI msgpack-numpy websocket protocol and returns GR00T actions as `dict[str, np.ndarray]`. ## Prerequisites @@ -22,5 +20,4 @@ The deploy config is `vllm_omni/deploy/Gr00tN1d7.yaml`. It registers `Gr00tN1d7P Unlike single-stream policies that return one ndarray, GR00T returns a per-action-key dictionary. vLLM-Omni preserves that dictionary under `multimodal_output["actions"]`, and the OpenPI endpoint sends it as the websocket success payload. -??? abstract "Example README" - --8<-- "examples/online_serving/gr00t/README.md" +See [`recipes/NVIDIA/GR00T-N1.7.md`](../../../../recipes/NVIDIA/GR00T-N1.7.md) for a full serving recipe with hardware requirements and verification steps. diff --git a/examples/online_serving/gr00t/README.md b/examples/online_serving/gr00t/README.md deleted file mode 100644 index fa66f5abc20..00000000000 --- a/examples/online_serving/gr00t/README.md +++ /dev/null @@ -1,116 +0,0 @@ -# GR00T-N1.7 OpenPI Example - -This example shows how to serve NVIDIA Isaac GR00T-N1.7 with `vllm serve --omni` -and connect a compatible OpenPI websocket client. 
The deployment is configured -for the DROID embodiment (`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`) and exposes -the GR00T action keys `eef_9d`, `gripper_position`, and `joint_position`. - -## Files - -- `run_server.sh`: launch GR00T-N1.7 OpenPI serving -- `openpi_client.py`: minimal websocket client that sends a synthetic GR00T observation -- `molmospace_gr00t_eval_demo.py`: MolmoSpaces evaluation demo wrapper - -## Environment requirements - -- `run_server.sh`, `vllm serve`, and `openpi_client.py`: use the local - `vllm-omni` environment. -- `openpi_client.py` extra deps: - -```bash -pip install openpi-client websockets -``` - -Optional MolmoSpaces dependencies: - -- `molmospace_gr00t_eval_demo.py` requires a working MolmoSpaces checkout with - the `examples.gr00t_openpi` package importable on `PYTHONPATH`, and uses EGL - for MuJoCo. The default asset cache location is - `$HOME/.cache/molmospaces/gr00t-assets`. - -## Start the server - -From the repository root: - -```bash -CUDA_VISIBLE_DEVICES=0 \ -examples/online_serving/gr00t/run_server.sh -``` - -The launcher honors: - -- `CUDA_VISIBLE_DEVICES`: GPU selection (default `0`) -- `DEPLOY_CONFIG`: stage config path (default `vllm_omni/deploy/Gr00tN1d7.yaml`) -- `HOST` / `PORT`: bind address (defaults `127.0.0.1:8000`) -- `SERVED_MODEL_NAME`: alias served via the OpenAI route (default `gr00t-n1d7`) -- `VLLM_WORKER_MULTIPROC_METHOD`: defaults to `spawn` - -The websocket endpoint is: - -- `ws://127.0.0.1:8000/v1/realtime/robot/openpi` - -The server handshake advertises the configured embodiment, image resolution -(`[180, 320]`), action horizon, and action key set through `policy_server_config` -in `vllm_omni/deploy/Gr00tN1d7.yaml`. 
- -## Run the client - -From the repository root: - -Environment: - -- run this in the `vllm-omni` repo environment -- if imports are missing, install `openpi-client` and `websockets` - -```bash -python examples/online_serving/gr00t/openpi_client.py \ - --host 127.0.0.1 \ - --port 8000 -``` - -The client sends one synthetic two-frame observation crafted for GR00T's DROID -contract: - -- exterior + wrist video tensors shaped `(1, 2, 180, 320, 3)` with a slowly - varying gradient pattern (constant frames have triggered SVD failures in the - Qwen3-VL backbone in practice) -- `state` with `eef_9d` (1,1,9), `gripper_position` (1,1,1), and - `joint_position` (1,1,7) -- `language` keyed by `annotation.language.language_instruction` - -It validates: - -- GR00T metadata contract (`image_resolution`, `action_horizon`, `action_keys`, - `embodiment_tag`) -- presence and 3D shape of each expected action key -- finite action values -- post-call reset response - -## Run MolmoSpaces benchmarks - -`molmospace_gr00t_eval_demo.py` evaluates GR00T-N1.7 through the same vLLM -OpenPI server on MolmoSpaces benchmarks. Install MolmoSpaces and make sure -`examples.gr00t_openpi.gr00t_openpi_policy` is importable, then run: - -```bash -BENCH="${MOLMOSPACES_BENCHMARK_DIR}/20260327/ithor/FrankaCloseHardBench" -BENCH="${BENCH}/FrankaCloseHardBench_20260206_json_benchmark" -PYTHONPATH="${MOLMOSPACES_ROOT}" \ -MUJOCO_EGL_DEVICE_ID=0 \ -python examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py \ - --host 127.0.0.1 \ - --port 8000 \ - --benchmark_dir "${BENCH}" \ - --output_dir outputs/gr00t/molmospaces \ - --max_episodes 1 \ - --task_horizon_steps 240 -``` - -`PYTHONPATH` must include the MolmoSpaces workspace root so that -`examples.gr00t_openpi.gr00t_openpi_policy` is importable. -`MUJOCO_EGL_DEVICE_ID` selects the EGL device for headless rendering; set -it to match the physical GPU index (may differ from `CUDA_VISIBLE_DEVICES`). 
- -The wrapper subclasses `Gr00tOpenPIPolicyConfig` / `Gr00tOpenPIEvalConfig` from -the MolmoSpaces workspace and overrides the policy backend host/port at runtime -so the eval harness talks to your local vLLM-Omni GR00T server. diff --git a/examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py b/examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py deleted file mode 100644 index a60c01be9b0..00000000000 --- a/examples/online_serving/gr00t/molmospace_gr00t_eval_demo.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -"""Drive MolmoSpaces evaluation against a running vLLM GR00T-N1.7 server. - -Wires the MolmoSpaces evaluation harness to the ``Gr00tOpenPIEvalConfig`` / -``Gr00tOpenPIPolicyConfig`` defined under -``examples.gr00t_openpi.gr00t_openpi_policy`` in the MolmoSpaces workspace and -points its policy backend at the local vLLM-Omni websocket server. -""" - -from __future__ import annotations - -import argparse -import os -import sys -from pathlib import Path - -# Remove this script's directory from sys.path so that `import openpi_client` -# resolves to the installed openpi-client package rather than our local -# openpi_client.py (same guard used in openpi_client.py). -_example_dir = str(Path(__file__).resolve().parent) -if sys.path and sys.path[0] == _example_dir: - sys.path.pop(0) - -os.environ.setdefault("MUJOCO_GL", "egl") -os.environ.setdefault("PYOPENGL_PLATFORM", "egl") -os.environ.setdefault("MLSPACES_ASSETS_DIR", str(Path.home() / ".cache" / "molmospaces" / "gr00t-assets")) - -_DEMO_HOST = os.environ.get("VLLM_OMNI_DEMO_HOST", "127.0.0.1") -_DEMO_PORT = int(os.environ.get("VLLM_OMNI_DEMO_PORT", "8000")) - -# Import the GR00T MolmoSpaces base configs at module top level so the -# subclasses below are pickle-resolvable (worker processes import this module -# fresh via __main__). 
-from examples.gr00t_openpi.gr00t_openpi_policy import ( # noqa: E402 - Gr00tOpenPIEvalConfig, - Gr00tOpenPIPolicyConfig, -) - - -class Gr00tVllmOmniPolicyConfig(Gr00tOpenPIPolicyConfig): - host: str = _DEMO_HOST - port: int = _DEMO_PORT - - -class Gr00tVllmOmniEvalConfig(Gr00tOpenPIEvalConfig): - policy_config: Gr00tVllmOmniPolicyConfig = Gr00tVllmOmniPolicyConfig() - - -def main() -> int: - parser = argparse.ArgumentParser() - parser.add_argument("--host", default="127.0.0.1") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--benchmark_dir", - required=True, - help=( - "Path to a MolmoSpaces benchmark directory, for example " - "$MOLMOSPACES_BENCHMARK_DIR/20260327/ithor/FrankaCloseHardBench/" - "FrankaCloseHardBench_20260206_json_benchmark" - ), - ) - parser.add_argument("--max_episodes", type=int, default=1) - parser.add_argument("--task_horizon_steps", type=int, default=80) - parser.add_argument( - "--output_dir", - required=True, - help="Directory to write evaluation outputs (created if missing).", - ) - parser.add_argument("--episode_idx", type=int, default=None) - args = parser.parse_args() - - os.environ["VLLM_OMNI_DEMO_HOST"] = args.host - os.environ["VLLM_OMNI_DEMO_PORT"] = str(args.port) - Gr00tVllmOmniPolicyConfig.model_fields["host"].default = args.host - Gr00tVllmOmniPolicyConfig.model_fields["port"].default = args.port - - # Import after env vars are set so MuJoCo picks EGL. 
- from molmo_spaces.evaluation import run_evaluation - - cfg_cls = Gr00tVllmOmniEvalConfig - - output_dir = args.output_dir - Path(output_dir).mkdir(parents=True, exist_ok=True) - - print(f"[eval] benchmark_dir={args.benchmark_dir}") - print(f"[eval] max_episodes={args.max_episodes} task_horizon_steps={args.task_horizon_steps}") - print(f"[eval] remote policy: ws://{args.host}:{args.port}/v1/realtime/robot/openpi") - - results = run_evaluation( - eval_config_cls=cfg_cls, - benchmark_dir=Path(args.benchmark_dir), - max_episodes=args.max_episodes, - task_horizon_steps=args.task_horizon_steps, - num_workers=1, - use_wandb=False, - output_dir=output_dir, - episode_idx=args.episode_idx, - ) - - print(f"[eval] success={results.success_count}/{results.total_count} ({results.success_rate:.1%})") - print(f"[eval] output_dir={results.output_dir}") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/examples/online_serving/gr00t/openpi_client.py b/examples/online_serving/gr00t/openpi_client.py deleted file mode 100644 index 1661b310f37..00000000000 --- a/examples/online_serving/gr00t/openpi_client.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -"""Minimal GR00T-N1.7 OpenPI websocket client demo. - -Sends a single synthetic observation crafted for the GR00T DROID embodiment -contract advertised by ``vllm_omni/deploy/Gr00tN1d7.yaml``: - -- two-frame video history at ``(180, 320)`` per camera (exterior + wrist) -- ``state`` with ``eef_9d`` (1,1,9), ``gripper_position`` (1,1,1) and - ``joint_position`` (1,1,7) -- ``language`` keyed by ``annotation.language.language_instruction`` - -Expects the server to return an action dict with keys -``{"eef_9d", "gripper_position", "joint_position"}``. - -The synthetic image uses a slowly varying gradient rather than a constant. 
-Constant frames have caused SVD failures in the Qwen3-VL visual backbone -during local smoke testing, so we deliberately keep some spatial variance. -""" - -from __future__ import annotations - -import argparse -import sys -import uuid -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -import numpy as np -from vllm.logger import init_logger - -try: - import websockets.sync.client -except ImportError as exc: # pragma: no cover - runtime dependency guard - raise ImportError("GR00T OpenPI example requires `websockets`.") from exc - -# NOTE: this directory does NOT contain a local file named ``openpi_client.py`` -# clashing with the installed package (this file imports nothing from it), but -# we apply the same defensive ``sys.path`` rewrite so the script works even -# when launched as ``python openpi_client.py`` from inside this directory. -try: - example_dir = str(Path(__file__).resolve().parent) - removed_path = False - if sys.path and sys.path[0] == example_dir: - sys.path.pop(0) - removed_path = True - try: - from openpi_client import msgpack_numpy - finally: - if removed_path: - sys.path.insert(0, example_dir) -except ImportError as exc: # pragma: no cover - runtime dependency guard - raise ImportError("GR00T OpenPI example requires `openpi-client`.") from exc - -logger = init_logger(__name__) - -PING_INTERVAL_SECS = 300 -PING_TIMEOUT_SECS = 3600 -DEFAULT_HOST = "127.0.0.1" -DEFAULT_PORT = 8000 -DEFAULT_PATH = "/v1/realtime/robot/openpi" -DEFAULT_PROMPT = "pick up the object and place it in the bin" -LANGUAGE_KEY = "annotation.language.language_instruction" -EXPECTED_ACTION_KEYS = ("eef_9d", "gripper_position", "joint_position") -IMAGE_HEIGHT = 180 -IMAGE_WIDTH = 320 - - -@dataclass(frozen=True) -class Gr00tServerMetadata: - image_resolution: tuple[int, int] - action_horizon: int - action_keys: tuple[str, ...] 
- embodiment_tag: str - needs_session_id: bool - - @classmethod - def from_dict(cls, payload: dict[str, Any]) -> Gr00tServerMetadata: - required_keys = ( - "image_resolution", - "action_horizon", - "action_keys", - "embodiment_tag", - "needs_session_id", - ) - missing_keys = [key for key in required_keys if key not in payload] - if missing_keys: - raise ValueError(f"Missing GR00T metadata keys: {missing_keys}") - - image_resolution = payload["image_resolution"] - if not isinstance(image_resolution, (list, tuple)) or len(image_resolution) != 2: - raise ValueError(f"Invalid image_resolution: {image_resolution!r}") - - return cls( - image_resolution=(int(image_resolution[0]), int(image_resolution[1])), - action_horizon=int(payload["action_horizon"]), - action_keys=tuple(str(k) for k in payload["action_keys"]), - embodiment_tag=str(payload["embodiment_tag"]), - needs_session_id=bool(payload["needs_session_id"]), - ) - - -class OpenPIWebsocketClient: - def __init__( - self, - *, - host: str = DEFAULT_HOST, - port: int = DEFAULT_PORT, - path: str = DEFAULT_PATH, - ) -> None: - self._uri = f"ws://{host}:{port}{path}" - self._packer = msgpack_numpy.Packer() - self._ws, self._server_metadata = self._connect() - - def _connect(self): - logger.info("Connecting to %s", self._uri) - conn = websockets.sync.client.connect( - self._uri, - compression=None, - max_size=None, - ping_interval=PING_INTERVAL_SECS, - ping_timeout=PING_TIMEOUT_SECS, - ) - metadata = msgpack_numpy.unpackb(conn.recv()) - if not isinstance(metadata, dict): - raise TypeError(f"Expected dict metadata from server, got {type(metadata)!r}") - return conn, metadata - - def get_server_metadata(self) -> dict[str, Any]: - return dict(self._server_metadata) - - def infer(self, obs: dict[str, Any]) -> dict[str, np.ndarray]: - payload = dict(obs) - payload["endpoint"] = "infer" - self._ws.send(self._packer.pack(payload)) - response = self._ws.recv() - if isinstance(response, str): - raise RuntimeError(f"Inference 
failed: {response}") - decoded = msgpack_numpy.unpackb(response) - if isinstance(decoded, dict) and decoded.get("type") == "error": - raise RuntimeError(f"GR00T server inference failed: {decoded.get('message')!r}") - if not isinstance(decoded, dict): - raise TypeError(f"Expected dict actions from GR00T server, got {type(decoded)!r}") - return {str(key): np.asarray(value, dtype=np.float32) for key, value in decoded.items()} - - def reset(self, reset_info: dict[str, Any] | None = None) -> str: - payload = dict(reset_info or {}) - payload["endpoint"] = "reset" - self._ws.send(self._packer.pack(payload)) - response = self._ws.recv() - if isinstance(response, str): - return response - decoded = msgpack_numpy.unpackb(response) - if not isinstance(decoded, dict) or decoded.get("status") != "reset successful": - raise RuntimeError(f"Unexpected reset response: {decoded!r}") - return str(decoded["status"]) - - def close(self) -> None: - self._ws.close() - - -def make_synthetic_frame(height: int = IMAGE_HEIGHT, width: int = IMAGE_WIDTH, *, seed: int = 0) -> np.ndarray: - """Produce a deterministic, non-constant RGB frame. - - Constant frames have triggered SVD failures in the Qwen3-VL backbone in - practice. We blend a horizontal and vertical gradient so the image has - real spatial variance while staying reproducible. - """ - - rng = np.random.default_rng(seed) - y_grad = np.linspace(0, 255, height, dtype=np.float32)[:, None] - x_grad = np.linspace(0, 255, width, dtype=np.float32)[None, :] - base = (0.5 * y_grad + 0.5 * x_grad).astype(np.float32) - frame = np.stack([base, np.flipud(base), base.T[:height, :width] if base.shape[0] == width else base], axis=-1) - if frame.shape != (height, width, 3): - # Fall back to a tiled gradient if the transpose path was incompatible. 
- frame = np.stack([base, 255.0 - base, (base + 64.0) % 256.0], axis=-1) - frame = frame + rng.uniform(-4.0, 4.0, size=frame.shape).astype(np.float32) - return np.clip(frame, 0, 255).astype(np.uint8) - - -def make_synthetic_observation(*, prompt: str, session_id: str) -> dict[str, Any]: - """Build a single GR00T observation with two-frame video history.""" - - exo_t0 = make_synthetic_frame(seed=1) - exo_t1 = make_synthetic_frame(seed=2) - wrist_t0 = make_synthetic_frame(seed=3) - wrist_t1 = make_synthetic_frame(seed=4) - - video = { - "exterior_image_1_left": np.stack([exo_t0, exo_t1])[None, ...], - "wrist_image_left": np.stack([wrist_t0, wrist_t1])[None, ...], - } - # eef_9d = xyz (3) + rot6d (6). A zero rot6d is rank-deficient and causes - # Rotation.from_matrix() to fail with "SVD did not converge". Use an identity - # pose instead: xyz=[0,0,0], rot6d=[1,0,0,0,1,0] (first two cols of I). - eef_9d_identity = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=np.float32) - state = { - "eef_9d": eef_9d_identity.reshape(1, 1, 9), - "gripper_position": np.zeros((1, 1, 1), dtype=np.float32), - "joint_position": np.zeros((1, 1, 7), dtype=np.float32), - } - language = {LANGUAGE_KEY: [[prompt]]} - - return { - "session_id": session_id, - "video": video, - "state": state, - "language": language, - } - - -def validate_actions( - actions: dict[str, np.ndarray], - *, - expected_action_horizon: int, -) -> None: - missing = [k for k in EXPECTED_ACTION_KEYS if k not in actions] - if missing: - raise AssertionError(f"Missing action keys from server response: {missing}") - - expected_dims = {"eef_9d": 9, "gripper_position": 1, "joint_position": 7} - for key, last_dim in expected_dims.items(): - action = actions[key] - if action.ndim != 3: - raise AssertionError(f"Action {key!r} must be 3D, got shape {action.shape}") - if action.shape[1] != expected_action_horizon: - raise AssertionError( - f"Action {key!r} horizon mismatch: expected {expected_action_horizon}, got 
{action.shape[1]}" - ) - if action.shape[-1] != last_dim: - raise AssertionError(f"Action {key!r} trailing dim mismatch: expected {last_dim}, got {action.shape[-1]}") - if not np.isfinite(action).all(): - raise AssertionError(f"Action {key!r} contains non-finite values") - - -def run_policy_session( - *, - host: str = DEFAULT_HOST, - port: int = DEFAULT_PORT, - path: str = DEFAULT_PATH, - prompt: str = DEFAULT_PROMPT, - session_id: str | None = None, -) -> dict[str, Any]: - session_id = session_id or str(uuid.uuid4()) - observation = make_synthetic_observation(prompt=prompt, session_id=session_id) - - client = OpenPIWebsocketClient(host=host, port=port, path=path) - try: - metadata = client.get_server_metadata() - actions = client.infer(observation) - reset_status = client.reset({"session_id": session_id}) - return { - "metadata": metadata, - "actions": actions, - "reset_status": reset_status, - "session_id": session_id, - } - finally: - client.close() - - -def format_action_summary(key: str, action: np.ndarray) -> str: - return ( - f"Action {key!r}: shape={tuple(action.shape)} dtype={action.dtype} " - f"min={float(action.min()):.6f} max={float(action.max()):.6f}" - ) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="GR00T-N1.7 OpenPI websocket client demo.") - parser.add_argument("--host", default=DEFAULT_HOST) - parser.add_argument("--port", type=int, default=DEFAULT_PORT) - parser.add_argument("--path", default=DEFAULT_PATH) - parser.add_argument("--prompt", default=DEFAULT_PROMPT) - parser.add_argument("--session-id", default=None) - return parser.parse_args() - - -def main() -> int: - args = parse_args() - result = run_policy_session( - host=args.host, - port=args.port, - path=args.path, - prompt=args.prompt, - session_id=args.session_id, - ) - - server_metadata = Gr00tServerMetadata.from_dict(result["metadata"]) - validate_actions(result["actions"], expected_action_horizon=server_metadata.action_horizon) - - 
print(f"Server embodiment: {server_metadata.embodiment_tag}") - print(f"Server image_resolution: {server_metadata.image_resolution}") - print(f"Server action_horizon: {server_metadata.action_horizon}") - print(f"Server action_keys: {server_metadata.action_keys}") - for key in sorted(result["actions"]): - print(format_action_summary(key, result["actions"][key])) - print(f"Reset status: {result['reset_status']}") - print(f"Session ID: {result['session_id']}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/examples/online_serving/gr00t/run_server.sh b/examples/online_serving/gr00t/run_server.sh deleted file mode 100755 index fc4ca79b7a8..00000000000 --- a/examples/online_serving/gr00t/run_server.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -set -euo pipefail - -MODEL="${MODEL:-nvidia/GR00T-N1.7-3B}" -HOST="${HOST:-127.0.0.1}" -PORT="${PORT:-8000}" -DEPLOY_CONFIG="${DEPLOY_CONFIG:-vllm_omni/deploy/Gr00tN1d7.yaml}" -SERVED_MODEL_NAME="${SERVED_MODEL_NAME:-gr00t-n1d7}" - -args=( - serve - "$MODEL" - --omni - --host "$HOST" - --port "$PORT" - --served-model-name "$SERVED_MODEL_NAME" - --stage-configs-path "$DEPLOY_CONFIG" - --disable-log-stats -) - -CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" \ -VLLM_WORKER_MULTIPROC_METHOD="${VLLM_WORKER_MULTIPROC_METHOD:-spawn}" \ -vllm "${args[@]}" diff --git a/recipes/NVIDIA/GR00T-N1.7.md b/recipes/NVIDIA/GR00T-N1.7.md new file mode 100644 index 00000000000..61615b18da6 --- /dev/null +++ b/recipes/NVIDIA/GR00T-N1.7.md @@ -0,0 +1,93 @@ +# GR00T-N1.7 + +> NVIDIA Isaac GR00T-N1.7-3B robot VLA policy served over the OpenPI WebSocket protocol + +## Summary + +- Vendor: NVIDIA +- Model: `nvidia/GR00T-N1.7-3B` +- Task: Vision-Language-Action (VLA) inference for robot manipulation +- Mode: Online serving via OpenPI WebSocket endpoint +- Maintainer: timzsu + +## When to use this recipe 
+ +Use this recipe when you need to serve GR00T-N1.7 as a real-time robot policy +over the OpenPI WebSocket API. It configures the DROID embodiment +(`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`) and exposes the standard DROID action +keys (`eef_9d`, `gripper_position`, `joint_position`) with action horizon 40. + +## References + +- Upstream model: <https://huggingface.co/nvidia/GR00T-N1.7-3B> +- Upstream codebase: <https://github.com/NVIDIA/Isaac-GR00T> +- OpenPI client library: <https://github.com/Physical-Intelligence/openpi> +- Pipeline: `vllm_omni.diffusion.models.gr00t.pipeline_gr00t.Gr00tN1d7Pipeline` +- Deploy config: [`vllm_omni/deploy/Gr00tN1d7.yaml`](../../vllm_omni/deploy/Gr00tN1d7.yaml) +- Docs: [`docs/user_guide/examples/online_serving/gr00t.md`](../../docs/user_guide/examples/online_serving/gr00t.md) +- E2E test: [`tests/e2e/online_serving/test_gr00t_openpi.py`](../../tests/e2e/online_serving/test_gr00t_openpi.py) + +## Hardware Support + +This recipe documents one CUDA GPU serving configuration. GR00T-N1.7-3B fits +on a single 48 GB GPU; smaller GPUs have not been validated. + +## GPU + +### 1x RTX 6000 Ada 48 GB (tested) + +#### Environment + +- OS: Linux +- Python: 3.11+ +- Driver / runtime: NVIDIA CUDA +- vLLM-Omni version or commit: Use the commit you are deploying from +- Extra Python deps: `pip install openpi-client websockets` + +#### Command + +```bash +vllm serve nvidia/GR00T-N1.7-3B \ + --omni \ + --host 127.0.0.1 \ + --port 8000 \ + --served-model-name gr00t-n1d7 \ + --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml \ + --disable-log-stats +``` + +The WebSocket endpoint is `ws://127.0.0.1:8000/v1/realtime/robot/openpi`. The +server handshake message (first frame after connect) is a msgpack-encoded dict +with `image_resolution`, `action_horizon`, `action_keys`, `embodiment_tag`, and `needs_session_id`.
+ +#### Verification + +```python +from tests.gr00t.openpi_client_helper import run_policy_session, validate_session_result +validate_session_result(run_policy_session(host="127.0.0.1", port=8000)) +``` + +Or run the e2e test suite: + +```bash +python -m pytest tests/e2e/online_serving/test_gr00t_openpi.py -v +``` + +The test sends a synthetic two-frame DROID observation and checks: + +- GR00T metadata contract: `image_resolution`, `action_horizon`, `action_keys`, `embodiment_tag` +- Action shapes: `eef_9d (1,40,9)`, `gripper_position (1,40,1)`, `joint_position (1,40,7)` +- All action values are finite float32 +- Reset response is `"reset successful"` + +#### Notes + +- Only `max_num_seqs: 1` is supported (configured in the deploy YAML); GR00T + policy state is per-session and not designed for concurrent batching. +- To switch embodiment, edit `embodiment_tag` under both `model_config` and + `policy_server_config` in `vllm_omni/deploy/Gr00tN1d7.yaml`. Supported values: + `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` (default), `XDOF`, `XDOF_SUBTASK`, + `REAL_G1`, `REAL_R1_PRO_SHARPA`, `LIBERO_PANDA`, `SIMPLER_ENV_GOOGLE`, + `SIMPLER_ENV_WIDOWX`. +- GR00T weights are loaded directly by `Gr00tPolicy` via `AutoModel.from_pretrained`; + the pipeline's `load_weights` is intentionally a no-op. 
diff --git a/tests/e2e/online_serving/test_gr00t_openpi.py b/tests/e2e/online_serving/test_gr00t_openpi.py index c40137232b9..0d80d0fbf07 100644 --- a/tests/e2e/online_serving/test_gr00t_openpi.py +++ b/tests/e2e/online_serving/test_gr00t_openpi.py @@ -4,6 +4,7 @@ import os +import numpy as np import pytest from tests.gr00t import openpi_client_helper as openpi_client @@ -22,7 +23,7 @@ model=MODEL, stage_config_path=get_deploy_config_path("Gr00tN1d7.yaml"), server_args=["--disable-log-stats"], - env_dict={"VLLM_DISABLE_COMPILE_CACHE": "1"}, + env_dict={"VLLM_DISABLE_COMPILE_CACHE": "1", "GR00T_NOISE_SEED": "42"}, init_timeout=1200, stage_init_timeout=900, ), @@ -30,6 +31,84 @@ ) ] +# Reference values captured from the Isaac-GR00T ZMQ server (nvidia/GR00T-N1.7-3B, +# OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT embodiment) using build_droid_observation() +# inputs: zero images (256×256), identity eef_9d, zero gripper/joint, "pick up the object". +# Outputs are bit-reproducible across resets and across runs (max_diff=0.0). 
+_REF_EEF_9D = np.array( + [ + [ + 0.014888550154864788, + -0.00039258378092199564, + -0.013574761338531971, + 0.9999837875366211, + 0.005678884219378233, + -0.00034708858584053814, + -0.005677700508385897, + 0.9999783635139465, + 0.0033210734836757183, + ], + [ + 0.020188070833683014, + -0.0003105341165792197, + -0.02232760563492775, + 0.9997346997261047, + 0.02301345206797123, + 0.000983047066256404, + -0.02302766777575016, + 0.9995648860931396, + 0.018431292846798897, + ], + [ + -0.007266733795404434, + -0.05537768080830574, + 0.03667901083827019, + 0.992686927318573, + 0.1206541359424591, + 0.003906540106981993, + -0.12024091929197311, + 0.9853788614273071, + 0.12070894986391068, + ], + ], + dtype=np.float32, +) # rows = step 0, step 4, step 39 + +_REF_GRIPPER = np.array([[0.0], [0.0078125], [0.939453125]], dtype=np.float32) # steps 0, 4, 39 + +_REF_JOINT = np.array( + [ + [ + -0.0010484338272362947, + 0.0014262489276006818, + -0.003565810853615403, + -3.846167237497866e-05, + -0.0002604846959002316, + 0.008521700277924538, + -0.006872728932648897, + ], + [ + -0.009435676969587803, + 0.0021475711837410927, + -0.0031688229646533728, + -1.8328893929719925e-05, + 0.0005945992306806147, + 0.019159257411956787, + 0.0009468861389905214, + ], + [ + -0.02944088727235794, + -0.08419207483530045, + -0.0251418836414814, + 0.00540524534881115, + 0.04752273112535477, + -0.012884500436484814, + 0.024298785254359245, + ], + ], + dtype=np.float32, +) # rows = step 0, step 4, step 39 + @pytest.mark.advanced_model @pytest.mark.diffusion @@ -42,3 +121,45 @@ def test_gr00t_n1d7_openpi_online(omni_server) -> None: session_id="gr00t-online-e2e", ) openpi_client.validate_session_result(result) + + +@pytest.mark.advanced_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100"}, num_cards=1) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_gr00t_n1d7_openpi_precision(omni_server) -> None: + """Assert actions match Isaac-GR00T reference 
(GR00T_NOISE_SEED=42, zero inputs). + + atol=1e-2 covers the ~0.006 max diff from flash-attn 2.7.4 (Isaac-GR00T) vs + vllm.vllm_flash_attn (vLLM-Omni), compounding over 50 denoising steps. + """ + client = openpi_client.OpenPIWebsocketClient(host=omni_server.host, port=omni_server.port) + try: + obs = openpi_client.build_droid_observation(session_id="gr00t-precision-e2e") + client.reset({}) + actions = client.infer(obs) + finally: + client.close() + + steps = [0, 4, 39] + np.testing.assert_allclose( + actions["eef_9d"][0, steps, :], + _REF_EEF_9D, + atol=1e-2, + rtol=0.0, + err_msg="eef_9d action mismatch vs Isaac-GR00T reference", + ) + np.testing.assert_allclose( + actions["gripper_position"][0, steps, :], + _REF_GRIPPER, + atol=1e-2, + rtol=0.0, + err_msg="gripper_position action mismatch vs Isaac-GR00T reference", + ) + np.testing.assert_allclose( + actions["joint_position"][0, steps, :], + _REF_JOINT, + atol=1e-2, + rtol=0.0, + err_msg="joint_position action mismatch vs Isaac-GR00T reference", + ) diff --git a/vllm_omni/diffusion/models/gr00t/__init__.py b/vllm_omni/diffusion/models/gr00t/__init__.py index 5a93bd798c6..a10202a40bc 100644 --- a/vllm_omni/diffusion/models/gr00t/__init__.py +++ b/vllm_omni/diffusion/models/gr00t/__init__.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm_omni.diffusion.models.gr00t.pipeline_gr00t import ( - Gr00tN1d7Pipeline, - get_gr00t_n1d7_post_process_func, -) +from vllm_omni.diffusion.models.gr00t.pipeline_gr00t import Gr00tN1d7Pipeline -__all__ = ["Gr00tN1d7Pipeline", "get_gr00t_n1d7_post_process_func"] +__all__ = ["Gr00tN1d7Pipeline"] diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py b/vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py index c7754c2ac9a..78411b88ec9 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py +++ 
b/vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag from vllm_omni.diffusion.models.gr00t.dataio.types import ( ActionConfig, ActionFormat, @@ -244,8 +243,3 @@ ), }, } - - -def register_modality_config(config: dict, embodiment_tag: EmbodimentTag = EmbodimentTag.NEW_EMBODIMENT): - assert embodiment_tag.value not in MODALITY_CONFIGS, f"Embodiment tag {embodiment_tag} already registered" - MODALITY_CONFIGS[embodiment_tag.value] = config diff --git a/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py index 9eb3ee84945..c938e06d097 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py @@ -90,6 +90,7 @@ class Gr00tN1d7Config(PretrainedConfig): attend_text_every_n_blocks: int = 2 diffusion_model_cfg: dict | None = None + vl_self_attention_cfg: dict | None = None # Flow matching parameters num_inference_timesteps: int = 4 diff --git a/vllm_omni/diffusion/models/gr00t/dataio/collator/__init__.py b/vllm_omni/diffusion/models/gr00t/dataio/collator/__init__.py deleted file mode 100644 index b39b02e3042..00000000000 --- a/vllm_omni/diffusion/models/gr00t/dataio/collator/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm_omni.diffusion.models.gr00t.dataio.collator.collators import BasicDataCollator - -__all__ = ["BasicDataCollator"] diff --git a/vllm_omni/diffusion/models/gr00t/dataio/collator/collators.py b/vllm_omni/diffusion/models/gr00t/dataio/collator/collators.py deleted file mode 100755 index 17a4c2ad474..00000000000 --- 
a/vllm_omni/diffusion/models/gr00t/dataio/collator/collators.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import torch - - -class BasicDataCollator: - def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]: - fields = features[0].keys() - batch = {} - for key in fields: - batch[key] = torch.stack([item[key] for item in features]) - return batch diff --git a/vllm_omni/diffusion/models/gr00t/dataio/interfaces.py b/vllm_omni/diffusion/models/gr00t/dataio/interfaces.py deleted file mode 100644 index f2d263600f4..00000000000 --- a/vllm_omni/diffusion/models/gr00t/dataio/interfaces.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Any - -import numpy as np -from transformers import ProcessorMixin - -from vllm_omni.diffusion.models.gr00t.dataio.types import EmbodimentTag, ModalityConfig - - -class BaseProcessor(ProcessorMixin): - def __call__(self, messages: list[dict[str, Any]]) -> dict[str, Any]: - """ - Process a list of messages and return a dictionary of model inputs. - - Args: - messages (list[dict[str, Any]]): List of messages to process. - - Returns: - dict[str, Any]: Dictionary of model inputs. - - Example: - >>> processor = BaseProcessor() - >>> messages = [ - >>> {"type": MessageType.START_OF_EPISODE.value, "content": ""}, - >>> {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, - >>> {"type": MessageType.TEXT.value, "role" : "user", "content": "Please give me the apple"}, - >>> {"type": MessageType.TEXT.value, "role" : "assistant", - >>> "content": "I need to move my left hand to get the apple"}, - >>> {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, - >>> {"type": MessageType.EPISODE_STEP.value, "content": VLAStepData}, - >>> {"type": MessageType.END_OF_EPISODE.value, "content": ""}, - >>> ] - >>> model_input = processor(messages) - >>> print(model_input) - """ - raise NotImplementedError("Subclasses must implement __call__") - - def decode_action( - self, - action: np.ndarray, - embodiment_tag: EmbodimentTag, - state: dict[str, np.ndarray] | None = None, - ) -> dict[str, np.ndarray]: - """Decode the action from the model output.""" - raise NotImplementedError("Subclasses must implement decode_action") - - @property - def collator(self): - raise NotImplementedError("Subclasses must implement collator") - - @abstractmethod - def set_statistics(self, statistics: dict[str, Any], override: bool = False) -> None: - """Set normalization statistics.""" - pass - - def get_modality_configs(self) 
-> dict[str, dict[str, ModalityConfig]]: - """Get the modality configurations. - - Returns: - dict[str, dict[str, ModalityConfig]]: The modality configurations, where - modality_configs[embodiment_tag][modality] = ModalityConfig - """ - return getattr(self, "modality_configs") - - -class ShardedDataset(ABC): - def __init__(self, dataset_path): - self.dataset_path = dataset_path - - @abstractmethod - def __len__(self) -> int: - """Return the number of shards.""" - pass - - @abstractmethod - def get_shard_length(self, idx: int) -> int: - """Get the length of the shard at index idx.""" - pass - - @abstractmethod - def get_shard(self, idx: int) -> list: - """Get the shard at index idx.""" - pass - - def set_processor(self, processor: BaseProcessor): - self.processor = processor - - def get_dataset_statistics(self) -> dict[str, Any]: - """Get dataset statistics.""" - raise NotImplementedError() diff --git a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py index 4ed48841022..5568cff9721 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py @@ -14,13 +14,16 @@ # limitations under the License. 
import logging +import os from typing import Any import torch from torch import nn from transformers import AutoConfig, AutoModel, PreTrainedModel from transformers.feature_extraction_utils import BatchFeature -from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig +from vllm.vllm_flash_attn import FA2_AVAILABLE as _FA2_AVAILABLE +from vllm.vllm_flash_attn import FA3_AVAILABLE as _FA3_AVAILABLE +from vllm.vllm_flash_attn import is_fa_version_supported as _is_fa_version_supported from vllm_omni.diffusion.models.gr00t.configs.model.gr00t_n1d7 import Gr00tN1d7Config from vllm_omni.diffusion.models.gr00t.modeling.modules.dit import AlternateVLDiT, DiT, SelfAttentionTransformer @@ -33,59 +36,6 @@ logger = logging.getLogger(__name__) -def _make_qwen3_vl_2b_config( - *, - backbone_embedding_dim: int, - num_hidden_layers: int, - attn_implementation: str, -) -> Qwen3VLConfig: - if backbone_embedding_dim != 2048: - raise ValueError(f"GR00T N1.7 expects a 2048-dim Qwen3-VL backbone, got {backbone_embedding_dim}") - - config = Qwen3VLConfig( - text_config={ - "vocab_size": 151936, - "hidden_size": backbone_embedding_dim, - "intermediate_size": 6144, - "num_hidden_layers": num_hidden_layers, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "head_dim": 128, - "max_position_embeddings": 262144, - "rope_theta": 5000000, - "rope_scaling": { - "mrope_interleaved": True, - "mrope_section": [24, 20, 20], - "rope_type": "default", - }, - "tie_word_embeddings": True, - "bos_token_id": 151643, - "eos_token_id": 151645, - "dtype": "bfloat16", - }, - vision_config={ - "depth": 24, - "hidden_size": 1024, - "intermediate_size": 4096, - "num_heads": 16, - "out_hidden_size": backbone_embedding_dim, - "deepstack_visual_indexes": [5, 11, 17], - "patch_size": 16, - "spatial_merge_size": 2, - "temporal_patch_size": 2, - "num_position_embeddings": 2304, - "in_channels": 3, - }, - image_token_id=151655, - video_token_id=151656, - vision_start_token_id=151652, - 
vision_end_token_id=151653, - tie_word_embeddings=False, - ) - config._attn_implementation = attn_implementation - return config - - class Gr00tN1d7ActionHead(nn.Module): """Action head component for flow matching diffusion policy.""" @@ -134,7 +84,7 @@ def __init__(self, config: Gr00tN1d7Config): self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity() - vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None) + vl_self_attention_cfg = config.vl_self_attention_cfg if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0: self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg) else: @@ -212,11 +162,21 @@ def get_action_with_features( # Set initial actions as the sampled noise. batch_size = vl_embeds.shape[0] device = vl_embeds.device - actions = torch.randn( - size=(batch_size, self.config.action_horizon, self.action_dim), - dtype=vl_embeds.dtype, - device=device, - ) + _seed_env = os.environ.get("GR00T_NOISE_SEED") + if _seed_env is not None: + _gen = torch.Generator(device=device).manual_seed(int(_seed_env)) + actions = torch.randn( + size=(batch_size, self.config.action_horizon, self.action_dim), + dtype=vl_embeds.dtype, + device=device, + generator=_gen, + ) + else: + actions = torch.randn( + size=(batch_size, self.config.action_horizon, self.action_dim), + dtype=vl_embeds.dtype, + device=device, + ) dt = 1.0 / self.num_inference_timesteps vel_strength = torch.ones_like(actions) @@ -364,29 +324,30 @@ def __init__( transformers_loading_kwargs: dict[str, Any] | None = None, ): super().__init__() - del model_name, reproject_vision, transformers_loading_kwargs + del reproject_vision if use_flash_attention: - try: - import flash_attn # noqa: F401 - + if _FA3_AVAILABLE and _is_fa_version_supported(3): + attn_implementation = "flash_attention_3" + elif _FA2_AVAILABLE and _is_fa_version_supported(2): attn_implementation = "flash_attention_2" - except ImportError: - logger.warning( 
- "flash_attn is not installed. Falling back to sdpa attention. " - "Install flash-attn for better performance: pip install flash-attn" - ) + else: + logger.warning("No supported flash attention backend on this device, falling back to sdpa.") attn_implementation = "sdpa" else: attn_implementation = "sdpa" - num_hidden_layers = select_layer if select_layer >= 0 else 28 - backbone_config = _make_qwen3_vl_2b_config( - backbone_embedding_dim=backbone_embedding_dim, - num_hidden_layers=num_hidden_layers, - attn_implementation=attn_implementation, + backbone_config = AutoConfig.from_pretrained( + model_name, **(transformers_loading_kwargs or {"trust_remote_code": True}) ) self.model = Qwen3VLForConditionalGeneration(backbone_config).eval() + # Set attention implementation post-init — avoids transformers' init-time flash_attn check. + # Text attention layers go through adapter_qwen3_vl which calls vllm_flash_attn directly. + self.model.config.text_config._attn_implementation = attn_implementation + self.model.model.language_model.config._attn_implementation = attn_implementation + # Vision model uses unpatched transformers attention — sdpa is the safe fallback + # (flash_attn and flash_attn_interface packages are absent; only fa3_fwd_interface is present). + self.model.config.vision_config._attn_implementation = "sdpa" if load_bf16: self.model.to(dtype=torch.bfloat16) @@ -493,24 +454,11 @@ def all_tied_weights_keys(self) -> dict[str, Any]: def __init__( self, config: Gr00tN1d7Config, - transformers_loading_kwargs: dict = {"trust_remote_code": True}, + transformers_loading_kwargs: dict | None = None, ): - """ - Initialize Gr00tN1d7 model. 
- - Args: - config: Model configuration - transformers_loading_kwargs: Dict with transformers loading parameters: - - transformers_trust_remote_code: Whether to trust remote code when loading from HF Hub - - transformers_local_files_only: Whether to only use local files - - model_revision: Specific model revision to use - - transformers_cache_dir: Directory to cache downloaded models - - transformers_access_token: HuggingFace access token for gated models - - Note: During training, transformers parameters are passed from training config. - During inference (e.g., from_pretrained), defaults are used. - """ super().__init__(config) + if transformers_loading_kwargs is None: + transformers_loading_kwargs = {"trust_remote_code": True} self.config = config backbone_cls = get_backbone_cls(config) @@ -541,7 +489,6 @@ def prepare_input(self, inputs: dict) -> tuple[BatchFeature, BatchFeature]: """Prepare inputs for backbone and action head.""" # NOTE -- currently the eval code doesn't use collator, so we need to add it here - # this should ideally be fixed upstream if "vlm_content" in inputs: # Fix for n_envs > 1: Process all environments' VLM content, not just the first vlm_content_list = inputs["vlm_content"] diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py old mode 100755 new mode 100644 index 3aa25e81c52..db424adab3b --- a/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py @@ -13,9 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -from contextlib import nullcontext - import torch import torch.nn.functional as F from diffusers import ConfigMixin, ModelMixin @@ -25,51 +22,6 @@ from torch import nn -def _is_spark_sm121() -> bool: - if not torch.cuda.is_available(): - return False - - major, minor = torch.cuda.get_device_capability() - return (major, minor) == (12, 1) - - -def _should_force_math_sdpa() -> bool: - override = os.environ.get("GR00T_DIT_SDPA_MODE") - if override == "math": - return True - if override == "default": - return False - - return _is_spark_sm121() - - -def _sdpa_context(): - # Spark (sm121) currently hits noisy/broken PyTorch mem-efficient SDPA kernel dispatch. - # Force the safe math backend there; on every other platform this returns a no-op context. - if not _should_force_math_sdpa(): - return nullcontext() - - return torch.backends.cuda.sdp_kernel( - enable_flash=False, - enable_math=True, - enable_mem_efficient=False, - enable_cudnn=False, - ) - - -class TimestepEncoder(nn.Module): - def __init__(self, embedding_dim, compute_dtype=torch.float32): - super().__init__() - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - def forward(self, timesteps): - dtype = next(self.parameters()).dtype - timesteps_proj = self.time_proj(timesteps).to(dtype) - timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D) - return timesteps_emb - - class AdaLayerNorm(nn.Module): def __init__( self, @@ -108,8 +60,6 @@ def __init__( attention_bias: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, - # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', - # 'ada_norm_continuous', 'layer_norm_i2vgen' norm_type: str = "layer_norm", norm_eps: float = 1e-5, final_dropout: bool = False, @@ -121,16 +71,6 @@ def __init__( attention_out_bias: bool = True, ): super().__init__() - self.dim = dim - 
self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.dropout = dropout - self.cross_attention_dim = cross_attention_dim - self.activation_fn = activation_fn - self.attention_bias = attention_bias - self.norm_elementwise_affine = norm_elementwise_affine - self.positional_embeddings = positional_embeddings - self.num_positional_embeddings = num_positional_embeddings self.norm_type = norm_type if positional_embeddings and (num_positional_embeddings is None): @@ -143,8 +83,6 @@ def __init__( else: self.pos_embed = None - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn if norm_type == "ada_norm": self.norm1 = AdaLayerNorm(dim) else: @@ -161,7 +99,6 @@ def __init__( out_bias=attention_out_bias, ) - # 3. Feed-forward self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) self.ff = FeedForward( dim, @@ -171,10 +108,7 @@ def __init__( inner_dim=ff_inner_dim, bias=ff_bias, ) - if final_dropout: - self.final_dropout = nn.Dropout(dropout) - else: - self.final_dropout = None + self.final_dropout = nn.Dropout(dropout) if final_dropout else None def forward( self, @@ -184,7 +118,6 @@ def forward( encoder_attention_mask: torch.Tensor | None = None, temb: torch.LongTensor | None = None, ) -> torch.Tensor: - # 0. 
Self-Attention if self.norm_type == "ada_norm": norm_hidden_states = self.norm1(hidden_states, temb) else: @@ -193,12 +126,11 @@ def forward( if self.pos_embed is not None: norm_hidden_states = self.pos_embed(norm_hidden_states) - with _sdpa_context(): - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=(encoder_attention_mask if encoder_hidden_states is not None else attention_mask), - ) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=(encoder_attention_mask if encoder_hidden_states is not None else attention_mask), + ) if self.final_dropout: attn_output = self.final_dropout(attn_output) @@ -206,7 +138,6 @@ def forward( if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - # 4. Feed-forward norm_hidden_states = self.norm3(hidden_states) ff_output = self.ff(norm_hidden_states) @@ -216,6 +147,17 @@ def forward( return hidden_states +class TimestepEncoder(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + dtype = next(self.parameters()).dtype + return self.timestep_embedder(self.time_proj(timesteps).to(dtype)) + + class DiT(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @@ -243,108 +185,78 @@ def __init__( ): super().__init__() - self.attention_head_dim = attention_head_dim - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.inner_dim = num_attention_heads * attention_head_dim self.gradient_checkpointing = False - - # Timestep encoder - self.timestep_encoder = TimestepEncoder(embedding_dim=self.inner_dim, compute_dtype=self.compute_dtype) + self.timestep_encoder = 
TimestepEncoder(embedding_dim=self.inner_dim) all_blocks = [] - for idx in range(self.config.num_layers): + for idx in range(num_layers): use_self_attn = idx % 2 == 1 and interleave_self_attention curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None - - all_blocks += [ + all_blocks.append( BasicTransformerBlock( self.inner_dim, - self.config.num_attention_heads, - self.config.attention_head_dim, - dropout=self.config.dropout, - activation_fn=self.config.activation_fn, - attention_bias=self.config.attention_bias, - upcast_attention=self.config.upcast_attention, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, norm_type=norm_type, - norm_elementwise_affine=self.config.norm_elementwise_affine, - norm_eps=self.config.norm_eps, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, positional_embeddings=positional_embeddings, - num_positional_embeddings=self.config.max_num_positional_embeddings, + num_positional_embeddings=max_num_positional_embeddings, final_dropout=final_dropout, cross_attention_dim=curr_cross_attention_dim, ) - ] + ) self.transformer_blocks = nn.ModuleList(all_blocks) - - # Output blocks self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) - self.proj_out_2 = nn.Linear(self.inner_dim, self.output_dim) - print( - "Total number of DiT parameters: ", - sum(p.numel() for p in self.parameters() if p.requires_grad), - ) + self.proj_out_2 = nn.Linear(self.inner_dim, output_dim) def forward( self, - hidden_states: torch.Tensor, # Shape: (B, T, D) - encoder_hidden_states: torch.Tensor, # Shape: (B, S, D) + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor | None = None, encoder_attention_mask: torch.Tensor | None = None, return_all_hidden_states: bool = False, ): - # 
Encode timesteps temb = self.timestep_encoder(timestep) - - # Process through transformer blocks - single pass through the blocks hidden_states = hidden_states.contiguous() encoder_hidden_states = encoder_hidden_states.contiguous() - all_hidden_states = [hidden_states] - # Process through transformer blocks for idx, block in enumerate(self.transformer_blocks): if idx % 2 == 1 and self.config.interleave_self_attention: - hidden_states = block( - hidden_states, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - temb=temb, - ) + hidden_states = block(hidden_states, temb=temb) else: hidden_states = block( hidden_states, - attention_mask=None, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=None, + encoder_attention_mask=encoder_attention_mask, temb=temb, ) all_hidden_states.append(hidden_states) - # Output processing - conditioning = temb - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] if return_all_hidden_states: return self.proj_out_2(hidden_states), all_hidden_states - else: - return self.proj_out_2(hidden_states) + return self.proj_out_2(hidden_states) class AlternateVLDiT(DiT): - """ - Alternate Vision-Language DiT that separates image and non-image tokens - during cross-attention processing. 
- """ - def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs): super().__init__(*args, **kwargs) self.attend_text_every_n_blocks = attend_text_every_n_blocks def forward( self, - hidden_states: torch.Tensor, # Shape: (B, T, D) - encoder_hidden_states: torch.Tensor, # Shape: (B, S, D) + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor | None = None, encoder_attention_mask: torch.Tensor | None = None, return_all_hidden_states: bool = False, @@ -352,61 +264,38 @@ def forward( backbone_attention_mask: torch.Tensor | None = None, ): assert image_mask is not None, "Image mask is required" + assert self.config.interleave_self_attention, "Interleave self attention must be enabled" - # Encode timesteps temb = self.timestep_encoder(timestep) - - # Process through transformer blocks - single pass through the blocks hidden_states = hidden_states.contiguous() encoder_hidden_states = encoder_hidden_states.contiguous() - # Create attention masks for image and non-image tokens - # image_mask shape: (B, S) where True indicates image tokens - # For attention, we need to invert: False means "don't attend to this token" - image_attention_mask = image_mask & backbone_attention_mask non_image_attention_mask = (~image_mask) & backbone_attention_mask all_hidden_states = [hidden_states] - assert self.config.interleave_self_attention, "Interleave self attention must be enabled" - - # Process through transformer blocks for idx, block in enumerate(self.transformer_blocks): if idx % 2 == 1: - # Self-attention blocks - hidden_states = block( - hidden_states, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - temb=temb, - ) + hidden_states = block(hidden_states, temb=temb) else: - # Cross-attention blocks - alternate between non-image and image tokens - if idx % (2 * self.attend_text_every_n_blocks) == 0: - # Attend to non-image tokens - curr_encoder_attention_mask = non_image_attention_mask - 
else: - # Attend to image tokens - curr_encoder_attention_mask = image_attention_mask - + curr_encoder_attention_mask = ( + non_image_attention_mask + if idx % (2 * self.attend_text_every_n_blocks) == 0 + else image_attention_mask + ) hidden_states = block( hidden_states, - attention_mask=None, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=curr_encoder_attention_mask, temb=temb, ) all_hidden_states.append(hidden_states) - # Output processing - conditioning = temb - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] if return_all_hidden_states: return self.proj_out_2(hidden_states), all_hidden_states - else: - return self.proj_out_2(hidden_states) + return self.proj_out_2(hidden_states) class SelfAttentionTransformer(ModelMixin, ConfigMixin): @@ -432,47 +321,36 @@ def __init__( ): super().__init__() - self.attention_head_dim = attention_head_dim - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.inner_dim = num_attention_heads * attention_head_dim self.gradient_checkpointing = False - self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.inner_dim, - self.config.num_attention_heads, - self.config.attention_head_dim, - dropout=self.config.dropout, - activation_fn=self.config.activation_fn, - attention_bias=self.config.attention_bias, - upcast_attention=self.config.upcast_attention, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, positional_embeddings=positional_embeddings, - num_positional_embeddings=self.config.max_num_positional_embeddings, + num_positional_embeddings=max_num_positional_embeddings, final_dropout=final_dropout, ) - for _ in range(self.config.num_layers) + for _ in range(num_layers) ] ) - print( 
- "Total number of SelfAttentionTransformer parameters: ", - sum(p.numel() for p in self.parameters() if p.requires_grad), - ) def forward( self, - hidden_states: torch.Tensor, # Shape: (B, T, D) + hidden_states: torch.Tensor, return_all_hidden_states: bool = False, ): - # Process through transformer blocks - single pass through the blocks hidden_states = hidden_states.contiguous() all_hidden_states = [hidden_states] - - # Process through transformer blocks - for idx, block in enumerate(self.transformer_blocks): + for block in self.transformer_blocks: hidden_states = block(hidden_states) all_hidden_states.append(hidden_states) - if return_all_hidden_states: return hidden_states, all_hidden_states - else: - return hidden_states + return hidden_states diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py index df9e3cde38d..b740f4eadd0 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py @@ -13,47 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - import torch import torch.nn.functional as F from torch import nn - -def swish(x): - """Swish activation function.""" - return x * torch.sigmoid(x) - - -class SinusoidalPositionalEncoding(nn.Module): - """ - Produces a sinusoidal encoding of shape (B, T, w) - given timesteps of shape (B, T). 
- """ - - def __init__(self, embedding_dim): - super().__init__() - self.embedding_dim = embedding_dim - - def forward(self, timesteps): - # timesteps: shape (B, T) - # We'll compute sin/cos frequencies across dim T - timesteps = timesteps.float() # ensure float - - B, T = timesteps.shape - device = timesteps.device - - half_dim = self.embedding_dim // 2 - # typical log space frequencies for sinusoidal encoding - exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (math.log(10000.0) / half_dim) - # Expand timesteps to (B, T, 1) then multiply - freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim) - - sin = torch.sin(freqs) - cos = torch.cos(freqs) - enc = torch.cat([sin, cos], dim=-1) # (B, T, w) - - return enc +from vllm_omni.diffusion.utils.flow_matching import SinusoidalPositionalEncoding, swish class CategorySpecificLinear(nn.Module): @@ -125,17 +89,6 @@ def expand_action_dimension(self, old_action_dim, new_action_dim, expand_input=F self.b = nn.Parameter(new_b) -class SmallMLP(nn.Module): - def __init__(self, input_dim, hidden_dim, output_dim): - super().__init__() - self.layer1 = nn.Linear(input_dim, hidden_dim) - self.layer2 = nn.Linear(hidden_dim, output_dim) - - def forward(self, x): - hidden = F.relu(self.layer1(x)) - return self.layer2(hidden) - - class CategorySpecificMLP(nn.Module): """Two-layer MLP with category-specific weights for multi-embodiment support.""" diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py index fcc59ab11c3..20f2bc06aed 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py @@ -13,57 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math - import torch -import torch.nn.functional as F from torch import nn +from vllm_omni.diffusion.utils.flow_matching import SinusoidalPositionalEncoding, swish -def swish(x): - return x * torch.sigmoid(x) - - -class SinusoidalPositionalEncoding(nn.Module): - """ - Produces a sinusoidal encoding of shape (B, T, w) - given timesteps of shape (B, T). - """ - - def __init__(self, embedding_dim): - super().__init__() - self.embedding_dim = embedding_dim - - def forward(self, timesteps): - # timesteps: shape (B, T) - # We'll compute sin/cos frequencies across dim T - timesteps = timesteps.float() # ensure float - - B, T = timesteps.shape - device = timesteps.device - - half_dim = self.embedding_dim // 2 - # typical log space frequencies for sinusoidal encoding - exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (math.log(10000.0) / half_dim) - # Expand timesteps to (B, T, 1) then multiply - freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim) - - sin = torch.sin(freqs) - cos = torch.cos(freqs) - enc = torch.cat([sin, cos], dim=-1) # (B, T, w) - - return enc - - -class SmallMLP(nn.Module): - def __init__(self, input_dim, hidden_dim, output_dim): - super().__init__() - self.layer1 = nn.Linear(input_dim, hidden_dim) - self.layer2 = nn.Linear(hidden_dim, output_dim) - - def forward(self, x): - hidden = F.relu(self.layer1(x)) - return self.layer2(hidden) +__all__ = ["ActionEncoder", "SinusoidalPositionalEncoding", "swish"] class ActionEncoder(nn.Module): diff --git a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py index e67225fc4a2..6da55a29f72 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py @@ -25,14 +25,13 @@ import torch import torchvision.transforms.v2 as transforms from PIL import Image -from transformers import AutoProcessor +from 
transformers import AutoProcessor, ProcessorMixin, Qwen3VLProcessor from transformers.feature_extraction_utils import BatchFeature from transformers.utils import cached_file from vllm.logger import init_logger from vllm_omni.diffusion.models.gr00t.configs.embodiment.embodiment_configs import ModalityConfig from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag -from vllm_omni.diffusion.models.gr00t.dataio.interfaces import BaseProcessor from vllm_omni.diffusion.models.gr00t.dataio.state_action.state_action_processor import StateActionProcessor from vllm_omni.diffusion.models.gr00t.dataio.utils import parse_modality_configs, to_json_serializable @@ -74,11 +73,6 @@ def _build_eval_image_transform( ) -try: - from transformers import Qwen3VLProcessor -except ImportError: - Qwen3VLProcessor = None - logger = init_logger(__name__) # Suppress protobuf deprecation warnings @@ -107,10 +101,6 @@ def _build_eval_image_transform( def build_processor(model_name: str, transformers_loading_kwargs: dict) -> Qwen3VLProcessor: - if Qwen3VLProcessor is None: - raise ImportError( - "Qwen3VLProcessor is not available. Please upgrade transformers: pip install transformers>=4.52.0" - ) if model_name == "nvidia/Cosmos-Reason2-2B": # Cosmos-Reason2-2B lacks a Qwen3VLProcessor; fall back to upstream artifacts. 
logger.warning_once( @@ -176,7 +166,7 @@ def __str__(self): return f"Gr00tN1d7DataCollator(model_name={self.model_name}, model_type={self.model_type})" -class Gr00tN1d7Processor(BaseProcessor): +class Gr00tN1d7Processor(ProcessorMixin): data_collator_class = Gr00tN1d7DataCollator def __init__( diff --git a/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py b/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py index dd8c458d2dc..63f8e4ce83e 100644 --- a/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py +++ b/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py @@ -11,20 +11,11 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.models.gr00t.policy import Gr00tPolicy -from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.request import DUMMY_DIFFUSION_REQUEST_ID, OmniDiffusionRequest logger = init_logger(__name__) -def get_gr00t_n1d7_post_process_func(od_config: OmniDiffusionConfig): - del od_config - - def post_process_func(x): - return x - - return post_process_func - - def _to_float32_action_dict(actions: Mapping[str, Any]) -> dict[str, np.ndarray]: converted = {str(key): np.asarray(value, dtype=np.float32) for key, value in actions.items()} if not converted: @@ -32,10 +23,6 @@ def _to_float32_action_dict(actions: Mapping[str, Any]) -> dict[str, np.ndarray] return converted -def _default_device() -> str: - return "cuda" if torch.cuda.is_available() else "cpu" - - class Gr00tN1d7Pipeline(nn.Module): """GR00T N1.7 policy pipeline backed by vLLM-Omni's local GR00T port. 
@@ -48,18 +35,11 @@ class Gr00tN1d7Pipeline(nn.Module): def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = "") -> None: super().__init__() - self.od_config = od_config - self.prefix = prefix + model_config = od_config.model_config self.model_path = od_config.model - self.model_config = dict(od_config.model_config or {}) - custom_args = od_config.custom_pipeline_args or {} - - default_embodiment = "OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT" - self.embodiment_tag = str( - custom_args.get("embodiment_tag") or self.model_config.get("embodiment_tag") or default_embodiment - ) - self.strict = bool(custom_args.get("strict", self.model_config.get("strict", True))) - self.device = str(custom_args.get("device") or self.model_config.get("device") or _default_device()) + self.embodiment_tag = str(model_config.get("embodiment_tag") or "OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT") + self.strict = bool(model_config.get("strict", True)) + self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info("Loading GR00T N1.7 policy from %s with embodiment_tag=%s", self.model_path, self.embodiment_tag) self.policy = Gr00tPolicy( @@ -70,19 +50,19 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = "") -> None: ) def reset(self) -> dict[str, Any]: - reset = getattr(self.policy, "reset", None) - if callable(reset): - info = reset() - return info or {} - return {} + return self.policy.reset() or {} @property def weights_sources(self) -> tuple[Any, ...]: return () def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - for _ in weights: - pass + consumed = list(weights) + if consumed: + raise RuntimeError( + f"Gr00tN1d7Pipeline.load_weights received {len(consumed)} weight tensors; " + "weights_sources=() should prevent this. GR00T weights are loaded directly by Gr00tPolicy." 
+ ) return set() def _dummy_actions(self) -> dict[str, np.ndarray]: @@ -100,10 +80,10 @@ def _dummy_actions(self) -> dict[str, np.ndarray]: @torch.inference_mode() def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput: del kwargs - extra_args = getattr(req.sampling_params, "extra_args", {}) or {} + extra_args = req.sampling_params.extra_args or {} robot_obs = extra_args.get("robot_obs") if robot_obs is None: - if getattr(req, "request_ids", None) == ["dummy_req_id"]: + if req.request_id == DUMMY_DIFFUSION_REQUEST_ID: return DiffusionOutput(multimodal_output={"actions": self._dummy_actions()}) return DiffusionOutput(error="Gr00tN1d7Pipeline.forward expects sampling_params.extra_args['robot_obs'].") if not isinstance(robot_obs, Mapping): diff --git a/vllm_omni/diffusion/models/gr00t/policy.py b/vllm_omni/diffusion/models/gr00t/policy.py index 9ab3f63e7fe..738988cae00 100644 --- a/vllm_omni/diffusion/models/gr00t/policy.py +++ b/vllm_omni/diffusion/models/gr00t/policy.py @@ -13,13 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Gr00t Policy implementation for inference. 
- -This module provides the core policy classes for running Gr00t models: -- Gr00tPolicy: Base policy class for model inference -- Gr00tSimPolicyWrapper: Wrapper for compatibility with existing Gr00t simulation environments -""" - from pathlib import Path from typing import Any @@ -28,9 +21,7 @@ from transformers import AutoModel, AutoProcessor from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import FINETUNE_ONLY_TAGS, POSTTRAIN_TAGS, EmbodimentTag -from vllm_omni.diffusion.models.gr00t.dataio.interfaces import BaseProcessor from vllm_omni.diffusion.models.gr00t.dataio.types import MessageType, ModalityConfig, VLAStepData -from vllm_omni.diffusion.models.gr00t.policy_base import BasePolicy, PolicyWrapper def _rec_to_dtype(value: Any, dtype: torch.dtype) -> Any: @@ -44,7 +35,7 @@ def _rec_to_dtype(value: Any, dtype: torch.dtype) -> Any: return value -class Gr00tPolicy(BasePolicy): +class Gr00tPolicy: """Core policy class for Gr00t model inference. This policy handles the end-to-end inference pipeline: @@ -78,7 +69,7 @@ def __init__( from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7 # noqa: F401 from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor # noqa: F401 - super().__init__(strict=strict) + self.strict = strict if isinstance(embodiment_tag, str): embodiment_tag = EmbodimentTag.resolve(embodiment_tag) model_dir = Path(model_path) @@ -98,11 +89,11 @@ def __init__( if (model_dir / "processor").is_dir() and not (model_dir / "processor_config.json").exists() else model_dir ) - self.processor: BaseProcessor = AutoProcessor.from_pretrained(processor_dir) + self.processor = AutoProcessor.from_pretrained(processor_dir) # Store embodiment-specific configurations self.embodiment_tag = embodiment_tag - all_modality_configs = self.processor.get_modality_configs() + all_modality_configs = self.processor.modality_configs if self.embodiment_tag.value not in all_modality_configs: # Map raw checkpoint 
tag values to user-friendly enum names where possible. supported_lines = [] @@ -437,6 +428,16 @@ def check_action(self, action: dict[str, Any]) -> None: f"{len(self.modality_configs['action'].delta_indices)}. Got {action_arr.shape[1]}" ) + def get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + if self.strict: + self.check_observation(observation) + action, info = self._get_action(observation, options) + if self.strict: + self.check_action(action) + return action, info + def get_modality_config(self) -> dict[str, ModalityConfig]: return self.modality_configs @@ -450,251 +451,3 @@ def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: Dictionary containing the info after resetting the policy """ return {} - - -class Gr00tSimPolicyWrapper(PolicyWrapper): - """Wrapper for Gr00tPolicy to enable compatibility with existing Gr00t simulation environments. - - This wrapper is specifically designed for retro-fitting the Gr00t policy with the current - Gr00t simulation environment interface. It handles the transformation between the flat - observation format used by Gr00t sim environments (with keys like 'video.camera_name', - 'state.joint_positions') and the nested format expected by Gr00tPolicy. - - **Important**: If you are using other environments, custom robots, or building new environments, - you should use `Gr00tPolicy` directly and format your observations according to its interface. - This wrapper is only needed for compatibility with the existing Gr00t sim infrastructure. 
- - Key transformations performed by this wrapper: - - Observation keys: 'video.cam' -> observation['video']['cam'] - - Observation keys: 'state.joints' -> observation['state']['joints'] - - Language keys: 'task' or 'annotation.human.coarse_action' -> observation['language']['task'] - - Action keys: action['joints'] -> 'action.joints' - """ - - def __init__(self, policy: Gr00tPolicy, *, strict: bool = True): - """Initialize the wrapper around a Gr00tPolicy instance. - - Args: - policy: The Gr00tPolicy instance to wrap - strict: Whether to enforce strict validation (default: True) - """ - super().__init__(policy, strict=strict) - self.policy: Gr00tPolicy = policy - assert len(self.policy.modality_configs["language"].delta_indices) == 1, ( - "Only one language delta index is supported" - ) - - def check_observation(self, observation: dict[str, Any]) -> None: - """Validate observation from Gr00t sim environment format. - - This validation is specific to the flat observation format used by Gr00t sim environments. - Unlike Gr00tPolicy.check_observation which expects nested dicts, this expects flat keys. 
- - Expected observation structure (Gr00t sim format): - - Flat keys like 'video.camera_name': np.ndarray[np.uint8, (B, T, H, W, C)] - - Flat keys like 'state.state_name': np.ndarray[np.float32, (B, T, D)] - - Language keys: tuple[str] or list[str] with shape (B,) - - Key can be 'task' or 'annotation.human.coarse_action' (for DC envs) - - Args: - observation: Flat observation dictionary from Gr00t sim environment - - Raises: - AssertionError: If any validation check fails - """ - modality_configs = self.get_modality_config() - - for video_key in modality_configs["video"].modality_keys: - # Construct flat key expected in Gr00t sim environment - parsed_key = f"video.{video_key}" - assert parsed_key in observation, f"Video key '{parsed_key}' must be in observation" - - batched_video = observation[parsed_key] - - # Verify data type is numpy array - assert isinstance(batched_video, np.ndarray), ( - f"Video key '{video_key}' must be a numpy array. Got {type(batched_video)}" - ) - - # Verify dtype is uint8 (standard for image data, range 0-255) - assert batched_video.dtype == np.uint8, ( - f"Video key '{video_key}' must be a numpy array of type np.uint8. Got {batched_video.dtype}" - ) - - # Verify shape has 5 dimensions: (B, T, H, W, C) - assert batched_video.ndim == 5, ( - f"Video key '{video_key}' must be a numpy array of shape (B, T, H, W, C), got {batched_video.shape}" - ) - - # Verify temporal dimension matches the expected horizon from config - assert batched_video.shape[1] == len(modality_configs["video"].delta_indices), ( - f"Video key '{video_key}'s horizon must be " - f"{len(modality_configs['video'].delta_indices)}. Got {batched_video.shape[1]}" - ) - - # Verify channel dimension is 3 (RGB images) - assert batched_video.shape[-1] == 3, ( - f"Video key '{video_key}'s channel 'C' must be 3. 
Got {batched_video.shape[-1]}" - ) - - for state_key in modality_configs["state"].modality_keys: - # Construct flat key expected in Gr00t sim environment - parsed_key = f"state.{state_key}" - assert parsed_key in observation, f"State key '{parsed_key}' must be in observation" - - batched_state = observation[parsed_key] - - # Verify data type is numpy array - assert isinstance(batched_state, np.ndarray), ( - f"State key '{state_key}' must be a numpy array. Got {type(batched_state)}" - ) - - # Verify dtype is float32 (standard for continuous state values) - assert batched_state.dtype == np.float32, ( - f"State key '{state_key}' must be a numpy array of type np.float32. Got {batched_state.dtype}" - ) - - # Verify shape has 3 dimensions: (B, T, D) - assert batched_state.ndim == 3, ( - f"State key '{state_key}' must be a numpy array of shape (B, T, D), got {batched_state.shape}" - ) - - # Verify temporal dimension matches the expected horizon from config - assert batched_state.shape[1] == len(modality_configs["state"].delta_indices), ( - f"State key '{state_key}'s horizon must be " - f"{len(modality_configs['state'].delta_indices)}. Got {batched_state.shape[1]}" - ) - - for language_key in modality_configs["language"].modality_keys: - # DC envs use 'annotation.human.coarse_action' instead of 'task' - if language_key == "task" and "annotation.human.coarse_action" in observation: - language_key = "annotation.human.coarse_action" - - # Check that the expected language key exists - assert language_key in observation, f"Language key '{language_key}' must be in observation" - - # In Gr00t sim format, language is a tuple of strings (B,) - batched_language: tuple[str] | list[str] = observation[language_key] # (B,) - - # Verify outer structure is a tuple (batch dimension) - assert isinstance(batched_language, (tuple, list)), ( - f"Language key '{language_key}' must be a tuple or list. 
Got {type(batched_language)}" - ) - - # Verify each batch item is a string - assert isinstance(batched_language[0], str), ( - f"Language batch item must be a string. Got {type(batched_language[0])}" - ) - - def _get_action( - self, observation: dict[str, Any], options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Transform Gr00t sim observation format and compute actions. - - This method transforms the flat observation format from Gr00t sim environments - into the nested format expected by Gr00tPolicy, computes actions, and transforms - them back to the flat format expected by Gr00t sim environments. - - Input format (Gr00t sim): - - Flat keys: 'video.camera_name', 'state.state_name' - - Language: tuple[str] (B,) - - Output format (Gr00t sim): - - Flat keys: 'action.action_name' - - Args: - observation: Flat observation dictionary from Gr00t sim environment - options: Optional parameters (currently unused) - - Returns: - Tuple of (flat_actions_dict, info_dict) - """ - # Transform flat observation format to nested format expected by Gr00tPolicy - new_obs = {} - for modality in ["video", "state", "language"]: - new_obs[modality] = {} - for key in self.policy.modality_configs[modality].modality_keys: - if modality == "language": - # DC envs use 'annotation.human.coarse_action' instead of 'task' - if key == "task" and "annotation.human.coarse_action" in observation: - parsed_key = "annotation.human.coarse_action" - else: - parsed_key = key - else: - # Construct flat key (e.g., 'video.camera' or 'state.joints') - parsed_key = f"{modality}.{key}" - - arr = observation[parsed_key] - - # Transform to nested format - if modality == "language": - # Convert from tuple[str] or list[str] (B,) to list[list[str]] (B, 1) - # Each element becomes a list with one string for temporal dimension - new_obs[modality][key] = [[str(item)] for item in arr] - else: - # Video and state arrays are already in correct format (B, T, ...) 
- new_obs[modality][key] = arr - - # Compute actions using the underlying Gr00tPolicy - action, info = self.policy.get_action(new_obs, options) - - # Transform actions back to flat format for Gr00t sim environment - # action['joints'] -> 'action.joints' - return {f"action.{key}": action[key] for key in action}, info - - def check_action(self, action: dict[str, Any]) -> None: - """Validate action in Gr00t sim environment format. - - This validation is specific to the flat action format used by Gr00t sim environments. - Unlike Gr00tPolicy.check_action which expects nested dicts, this expects flat keys. - - Expected action structure (Gr00t sim format): - - Flat keys like 'action.action_name': np.ndarray[np.float32, (B, T, D)] - - B: batch size - - T: action horizon (number of future action steps) - - D: action dimension - - Args: - action: Flat action dictionary for Gr00t sim environment - - Raises: - AssertionError: If any validation check fails - """ - modality_configs = self.get_modality_config() - - # Validate each action key defined in the modality config - for action_key in modality_configs["action"].modality_keys: - # Construct flat key expected in Gr00t sim environment (e.g., 'action.joints') - parsed_key = f"action.{action_key}" - assert parsed_key in action, f"Action key '{parsed_key}' must be in action" - - action_arr = action[parsed_key] - - # Verify data type is numpy array - assert isinstance(action_arr, np.ndarray), ( - f"Action key '{action_key}' must be a numpy array. Got {type(action_arr)}" - ) - - # Verify dtype is float32 (standard for continuous actions) - assert action_arr.dtype == np.float32, ( - f"Action key '{action_key}' must be a numpy array of type np.float32. 
Got {action_arr.dtype}" - ) - - # Verify shape has 3 dimensions: (B, T, D) - assert action_arr.ndim == 3, ( - f"Action key '{action_key}' must be a numpy array of shape (B, T, D), got {action_arr.shape}" - ) - - # Verify action horizon matches the expected temporal dimension from config - assert action_arr.shape[1] == len(modality_configs["action"].delta_indices), ( - f"Action key '{action_key}'s horizon must be " - f"{len(modality_configs['action'].delta_indices)}. Got {action_arr.shape[1]}" - ) - - def get_modality_config(self) -> dict[str, ModalityConfig]: - """Get the modality configuration from the underlying policy. - - Returns: - Dictionary mapping modality names to their configurations - """ - return self.policy.get_modality_config() diff --git a/vllm_omni/diffusion/models/gr00t/policy_base.py b/vllm_omni/diffusion/models/gr00t/policy_base.py deleted file mode 100644 index bab968f9de3..00000000000 --- a/vllm_omni/diffusion/models/gr00t/policy_base.py +++ /dev/null @@ -1,132 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Any - - -class BasePolicy(ABC): - """Abstract base class for robotic control policies. 
- - This class defines the interface that all policies must implement, including - methods for action computation, input/output validation, and state management. - - Subclasses must implement: - - check_observation(): Validate observation format - - check_action(): Validate action format - - _get_action(): Core action computation logic - - reset(): Reset policy to initial state - """ - - def __init__(self, *, strict: bool = True): - self.strict = strict - - @abstractmethod - def check_observation(self, observation: dict[str, Any]) -> None: - """Check if the observation is valid. - - Args: - observation: Dictionary containing the current state/observation of the environment - - Raises: - AssertionError: If the observation is invalid. - """ - pass - - @abstractmethod - def check_action(self, action: dict[str, Any]) -> None: - """Check if the action is valid. - - Args: - action: Dictionary containing the action to be executed - - Raises: - AssertionError: If the action is invalid. - """ - pass - - @abstractmethod - def _get_action( - self, observation: dict[str, Any], options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Compute and return the next action based on current observation. - - This method should be overridden by subclasses to implement policy-specific - action computation. Input validation is handled by the public get_action() method. - - Args: - observation: Dictionary containing the current state/observation - options: Optional configuration dict for action computation - - Returns: - Tuple of (action, info): - - action: Dictionary containing the action to be executed - - info: Dictionary containing additional metadata (e.g., confidence scores) - """ - pass - - def get_action( - self, observation: dict[str, Any], options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Compute and return the next action based on current observation with validation. - - This is the main public interface. 
It validates the observation, calls - the internal _get_action(), and validates the resulting action. - - Args: - observation: Dictionary containing the current state/observation - options: Optional configuration dict for action computation - - Returns: - Tuple of (action, info): - - action: Dictionary containing the validated action - - info: Dictionary containing additional metadata - - Raises: - AssertionError/ValueError: If observation or action validation fails - """ - if self.strict: - self.check_observation(observation) - action, info = self._get_action(observation, options) - if self.strict: - self.check_action(action) - return action, info - - @abstractmethod - def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: - """Reset the policy to its initial state. - - Args: - options: Dictionary containing the options for the reset - - Returns: - Dictionary containing the info after resetting the policy - """ - pass - - -class PolicyWrapper(BasePolicy): - """Base wrapper class for composing policy behaviors. - - Note: This base implementation only forwards reset(). Subclasses should - implement validation logic and additional functionality as needed. 
- """ - - def __init__(self, policy: BasePolicy, *, strict: bool = True): - super().__init__(strict=strict) - self.policy = policy - - def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: - return self.policy.reset(options) diff --git a/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py b/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py index 40a4a111b6b..065596ced98 100644 --- a/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py +++ b/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py @@ -34,6 +34,50 @@ from transformers.models.qwen3_vl.modeling_qwen3_vl import ( Qwen3VLTextRMSNorm as HFQwen3VLTextRMSNorm, ) +from vllm.vllm_flash_attn import FA2_AVAILABLE as _FA2_AVAILABLE +from vllm.vllm_flash_attn import FA3_AVAILABLE as _FA3_AVAILABLE +from vllm.vllm_flash_attn import flash_attn_varlen_func as _vllm_fa_varlen +from vllm.vllm_flash_attn import is_fa_version_supported as _is_fa_version_supported + +_VLLM_FA3_OK: bool = _FA3_AVAILABLE and _is_fa_version_supported(3) +_VLLM_FA2_OK: bool = _FA2_AVAILABLE and _is_fa_version_supported(2) +_VLLM_FA_AVAILABLE: bool = _VLLM_FA3_OK or _VLLM_FA2_OK +_VLLM_FA_VERSION: int = 3 if _VLLM_FA3_OK else 2 + +_FLASH_IMPL_NAMES = frozenset({"flash_attention_2", "flash_attention_3", "flash_attention_4"}) + + +def _vllm_flash_attn_forward( + module, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask, # ignored — causal handled by flash attn + dropout: float = 0.0, + scaling: float | None = None, + **kwargs, +) -> tuple[torch.Tensor, None]: + batch_size, n_heads, seq_len, head_dim = query_states.shape + n_kv_heads = key_states.shape[1] + # (batch, heads, seq, dim) → (batch*seq, heads, dim) + q = query_states.transpose(1, 2).contiguous().view(batch_size * seq_len, n_heads, head_dim) + k = key_states.transpose(1, 2).contiguous().view(batch_size * seq_len, n_kv_heads, head_dim) + v = value_states.transpose(1, 
2).contiguous().view(batch_size * seq_len, n_kv_heads, head_dim) + cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=q.device) + out = _vllm_fa_varlen( + q, + k, + v, + max_seqlen_q=seq_len, + cu_seqlens_q=cu_seqlens, + max_seqlen_k=seq_len, + cu_seqlens_k=cu_seqlens, + softmax_scale=scaling, + causal=True, + fa_version=_VLLM_FA_VERSION, + ) + # (batch*seq, heads, dim) → (batch, seq, heads, dim) to match transformers interface + return out.view(batch_size, seq_len, n_heads, head_dim), None class Qwen3VLTextRMSNorm(HFQwen3VLTextRMSNorm): @@ -81,20 +125,31 @@ def forward( key_states = torch.cat([past_key_values[self.layer_idx][0], key_states], dim=2) value_states = torch.cat([past_key_values[self.layer_idx][1], value_states], dim=2) - attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( - self.config._attn_implementation, eager_attention_forward - ) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) + if _VLLM_FA_AVAILABLE and self.config._attn_implementation in _FLASH_IMPL_NAMES: + attn_output, attn_weights = _vllm_flash_attn_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -223,7 +278,6 @@ def __init__(self, config): class 
Qwen3VLForConditionalGeneration(HFQwen3VLForConditionalGeneration): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] accepts_loss_kwargs = False config: Qwen3VLConfig diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 6ce70dd2b3a..bbae141116b 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -478,7 +478,6 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "BagelPipeline": "get_bagel_post_process_func", "MingImagePipeline": "get_ming_image_post_process_func", "InternVLAA1Pipeline": "get_internvla_a1_post_process_func", - "Gr00tN1d7Pipeline": "get_gr00t_n1d7_post_process_func", "LongCatImageEditPipeline": "get_longcat_image_post_process_func", "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", "FluxKontextPipeline": "get_flux_kontext_post_process_func", diff --git a/vllm_omni/diffusion/utils/flow_matching.py b/vllm_omni/diffusion/utils/flow_matching.py new file mode 100644 index 00000000000..e81d405ec23 --- /dev/null +++ b/vllm_omni/diffusion/utils/flow_matching.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import torch +from torch import nn + + +def swish(x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, embedding_dim: int) -> None: + super().__init__() + self.embedding_dim = embedding_dim + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + timesteps = timesteps.float() + B, T = timesteps.shape + device = timesteps.device + + half_dim = self.embedding_dim // 2 + exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (math.log(10000.0) / half_dim) + freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim) + enc = torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1) # (B, T, embedding_dim) + return 
enc From 257ac6f6c50f73126a6cb3388214f53fce2f2946 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Mon, 1 Jun 2026 19:40:59 +0700 Subject: [PATCH 6/8] refactor(gr00t): trim training infra and flatten configs Signed-off-by: Zhengyuan Su --- .../test_diffusion_engine_actions.py | 71 ------------------- .../openai_api/test_openpi_serving.py | 10 --- .../models/gr00t/configs/__init__.py | 6 ++ .../gr00t/configs/embodiment/__init__.py | 2 - .../{embodiment => }/embodiment_configs.py | 0 .../gr00t/configs/{model => }/gr00t_n1d7.py | 3 - .../models/gr00t/configs/model/__init__.py | 8 --- .../state_action/state_action_processor.py | 2 +- .../diffusion/models/gr00t/dataio/utils.py | 2 +- .../models/gr00t/modeling/gr00t_n1d7.py | 39 +--------- .../gr00t/modeling/processing_gr00t_n1d7.py | 2 +- 11 files changed, 11 insertions(+), 134 deletions(-) delete mode 100644 tests/diffusion/test_diffusion_engine_actions.py delete mode 100644 vllm_omni/diffusion/models/gr00t/configs/embodiment/__init__.py rename vllm_omni/diffusion/models/gr00t/configs/{embodiment => }/embodiment_configs.py (100%) rename vllm_omni/diffusion/models/gr00t/configs/{model => }/gr00t_n1d7.py (97%) delete mode 100644 vllm_omni/diffusion/models/gr00t/configs/model/__init__.py diff --git a/tests/diffusion/test_diffusion_engine_actions.py b/tests/diffusion/test_diffusion_engine_actions.py deleted file mode 100644 index 8408c417649..00000000000 --- a/tests/diffusion/test_diffusion_engine_actions.py +++ /dev/null @@ -1,71 +0,0 @@ -import asyncio -import importlib.util -import types - -import numpy as np -import pytest - -from vllm_omni.diffusion.data import DiffusionOutput -from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.inputs.data import OmniDiffusionSamplingParams - -HAS_VLLM_IR = importlib.util.find_spec("vllm.ir") is not None -if HAS_VLLM_IR: - from vllm_omni.diffusion.diffusion_engine import DiffusionEngine - -pytestmark = [ - pytest.mark.core_model, - pytest.mark.cpu, - 
pytest.mark.skipif(not HAS_VLLM_IR, reason="Installed vLLM does not provide vllm.ir"), -] - - -def _engine_with_output(output: DiffusionOutput): - engine = DiffusionEngine.__new__(DiffusionEngine) - engine.od_config = types.SimpleNamespace(enable_cpu_offload=False, model_class_name="Gr00tN1d7Pipeline") - engine.pre_process_func = None - engine.post_process_func = None - engine._post_process_accepts_sampling_params = False - - async def check_loop(): - return None - - async def run_request(request): - return output - - engine._check_and_start_background_loop = check_loop - engine.async_add_req_and_wait_for_response = run_request - return engine - - -def test_diffusion_engine_surfaces_action_multimodal_output(): - actions = {"arm": np.array([[[1.0, 2.0]]], dtype=np.float32)} - engine = _engine_with_output(DiffusionOutput(multimodal_output={"actions": actions})) - req = OmniDiffusionRequest( - prompts=["pick"], - request_ids=["req"], - sampling_params=OmniDiffusionSamplingParams(), - ) - - outputs = asyncio.run(engine.step(req)) - - assert len(outputs) == 1 - assert outputs[0].images == [] - assert outputs[0].final_output_type == "actions" - assert outputs[0].multimodal_output["actions"] is actions - - -def test_diffusion_engine_surfaces_actions_from_output_dict(): - actions = {"arm": np.array([[[1.0, 2.0]]], dtype=np.float32)} - engine = _engine_with_output(DiffusionOutput(output={"actions": actions})) - req = OmniDiffusionRequest( - prompts=["pick"], - request_ids=["req"], - sampling_params=OmniDiffusionSamplingParams(), - ) - - outputs = asyncio.run(engine.step(req)) - - assert outputs[0].images == [] - assert outputs[0].final_output_type == "actions" - assert outputs[0].multimodal_output["actions"] is actions diff --git a/tests/entrypoints/openai_api/test_openpi_serving.py b/tests/entrypoints/openai_api/test_openpi_serving.py index 597ba9f3028..9eb1d5bfe0b 100644 --- a/tests/entrypoints/openai_api/test_openpi_serving.py +++ 
b/tests/entrypoints/openai_api/test_openpi_serving.py @@ -2,7 +2,6 @@ import json import threading from concurrent.futures import ThreadPoolExecutor -from pathlib import Path from types import SimpleNamespace import numpy as np @@ -25,8 +24,6 @@ "action_space": "joint_position", } -GR00T_DEPLOY_CONFIG = Path(__file__).resolve().parents[3] / "vllm_omni" / "deploy" / "Gr00tN1d7.yaml" - def _json_default(obj): if isinstance(obj, np.ndarray): @@ -104,13 +101,6 @@ async def _generate(): return _generate() -def test_gr00t_deploy_reports_droid_image_resolution(): - config = OmegaConf.load(GR00T_DEPLOY_CONFIG) - policy_config = config.stages[0].model_config.policy_server_config - - assert list(policy_config.image_resolution) == [180, 320] - - def test_policy_server_config_reads_diffusion_model_config(): policy_config = { "image_resolution": [64, 64], diff --git a/vllm_omni/diffusion/models/gr00t/configs/__init__.py b/vllm_omni/diffusion/models/gr00t/configs/__init__.py index 208f01a7cb5..9648b859751 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/__init__.py +++ b/vllm_omni/diffusion/models/gr00t/configs/__init__.py @@ -1,2 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +MODEL_CONFIG_TYPES: dict[str, type] = {} + + +def register_model_config(shortname: str, configtype: type) -> None: + MODEL_CONFIG_TYPES[shortname] = configtype diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment/__init__.py b/vllm_omni/diffusion/models/gr00t/configs/embodiment/__init__.py deleted file mode 100644 index 208f01a7cb5..00000000000 --- a/vllm_omni/diffusion/models/gr00t/configs/embodiment/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py similarity index 100% 
rename from vllm_omni/diffusion/models/gr00t/configs/embodiment/embodiment_configs.py rename to vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py diff --git a/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py similarity index 97% rename from vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py rename to vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py index c938e06d097..dc08a0cffe3 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/model/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py @@ -48,10 +48,7 @@ class Gr00tN1d7Config(PretrainedConfig): model_name: str = "nvidia/Cosmos-Reason2-2B" backbone_model_type: str = "qwen" model_revision: str | None = None - tune_top_llm_layers: int = 0 # Number of top LLM layers to tune backbone_embedding_dim: int = 2048 # project_to_dim; must match Cosmos-Reason2-2B hidden size - tune_llm: bool = False - tune_visual: bool = False select_layer: int = 12 reproject_vision: bool = False use_flash_attention: bool = True diff --git a/vllm_omni/diffusion/models/gr00t/configs/model/__init__.py b/vllm_omni/diffusion/models/gr00t/configs/model/__init__.py deleted file mode 100644 index 9648b859751..00000000000 --- a/vllm_omni/diffusion/models/gr00t/configs/model/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -MODEL_CONFIG_TYPES: dict[str, type] = {} - - -def register_model_config(shortname: str, configtype: type) -> None: - MODEL_CONFIG_TYPES[shortname] = configtype diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py index 3c861c0c259..cce6362e4e5 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py +++ 
b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py @@ -19,7 +19,7 @@ import numpy as np -from vllm_omni.diffusion.models.gr00t.configs.embodiment.embodiment_configs import ( +from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import ( ActionFormat, ActionRepresentation, ActionType, diff --git a/vllm_omni/diffusion/models/gr00t/dataio/utils.py b/vllm_omni/diffusion/models/gr00t/dataio/utils.py index 97ee660dfff..379b0d87606 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/utils.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/utils.py @@ -19,7 +19,7 @@ import numpy as np -from vllm_omni.diffusion.models.gr00t.configs.embodiment.embodiment_configs import ModalityConfig +from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import ModalityConfig def apply_sin_cos_encoding(values: np.ndarray) -> np.ndarray: diff --git a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py index 5568cff9721..5b14c61c7a7 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py @@ -25,7 +25,7 @@ from vllm.vllm_flash_attn import FA3_AVAILABLE as _FA3_AVAILABLE from vllm.vllm_flash_attn import is_fa_version_supported as _is_fa_version_supported -from vllm_omni.diffusion.models.gr00t.configs.model.gr00t_n1d7 import Gr00tN1d7Config +from vllm_omni.diffusion.models.gr00t.configs.gr00t_n1d7 import Gr00tN1d7Config from vllm_omni.diffusion.models.gr00t.modeling.modules.dit import AlternateVLDiT, DiT, SelfAttentionTransformer from vllm_omni.diffusion.models.gr00t.modeling.modules.embodiment_conditioned_mlp import ( CategorySpecificMLP, @@ -313,14 +313,11 @@ class _Qwen3VLBackbone(nn.Module): def __init__( self, model_name: str, - tune_llm: bool, - tune_visual: bool, select_layer: int, reproject_vision: bool, use_flash_attention: bool, backbone_embedding_dim: int, load_bf16: bool, - tune_top_llm_layers: int, 
transformers_loading_kwargs: dict[str, Any] | None = None, ): super().__init__() @@ -355,37 +352,8 @@ def __init__( while len(self.model.model.language_model.layers) > target_layers: self.model.model.language_model.layers.pop(-1) - self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers) - - def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int) -> None: - self.tune_llm = tune_llm - self.tune_visual = tune_visual - for param in self.parameters(): - param.requires_grad = True - if not tune_llm: - self.model.model.language_model.requires_grad_(False) - if not tune_visual: - self.model.model.visual.requires_grad_(False) - - if tune_top_llm_layers > 0: - for layer in self.model.model.language_model.layers[-tune_top_llm_layers:]: - for param in layer.parameters(): - param.requires_grad = True - - logger.debug("Tune backbone llm: %s", self.tune_llm) - logger.debug("Tune backbone visual: %s", self.tune_visual) - for name, param in self.named_parameters(): - if param.requires_grad: - logger.debug("Backbone trainable parameter: %s", name) - if not any(param.requires_grad for param in self.parameters()): - logger.warning("No backbone trainable parameters found.") - def set_frozen_modules_to_eval_mode(self) -> None: - if self.training: - if self.model.model.language_model and not self.tune_llm: - self.model.model.language_model.eval() - if self.model.model.visual and not self.tune_visual: - self.model.model.visual.eval() + self.eval() def prepare_input(self, batch: dict) -> BatchFeature: return BatchFeature(data=batch) @@ -464,14 +432,11 @@ def __init__( backbone_cls = get_backbone_cls(config) self.backbone = backbone_cls( model_name=config.model_name, - tune_llm=config.tune_llm, - tune_visual=config.tune_visual, select_layer=config.select_layer, reproject_vision=config.reproject_vision, use_flash_attention=config.use_flash_attention, backbone_embedding_dim=config.backbone_embedding_dim, load_bf16=config.load_bf16, - 
tune_top_llm_layers=config.tune_top_llm_layers, transformers_loading_kwargs=transformers_loading_kwargs, ) diff --git a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py index 6da55a29f72..fb60b9b2afd 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py @@ -30,7 +30,7 @@ from transformers.utils import cached_file from vllm.logger import init_logger -from vllm_omni.diffusion.models.gr00t.configs.embodiment.embodiment_configs import ModalityConfig +from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import ModalityConfig from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag from vllm_omni.diffusion.models.gr00t.dataio.state_action.state_action_processor import StateActionProcessor from vllm_omni.diffusion.models.gr00t.dataio.utils import parse_modality_configs, to_json_serializable From eb3b6a968180929a201dadb13705e85bcc06e43a Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Mon, 1 Jun 2026 23:02:50 +0700 Subject: [PATCH 7/8] feat(gr00t): clean up GR00T-N1.7 serving code Signed-off-by: Zhengyuan Su --- docs/.nav.yml | 1 - .../examples/online_serving/gr00t.md | 23 -- recipes/NVIDIA/GR00T-N1.7.md | 39 ++- tests/e2e/online_serving/test_gr00t_openpi.py | 6 +- .../gr00t/configs/embodiment_configs.py | 223 +----------------- .../models/gr00t/configs/gr00t_n1d7.py | 171 +++++++------- .../models/gr00t/modeling/gr00t_n1d7.py | 42 +--- vllm_omni/diffusion/models/gr00t/policy.py | 16 +- .../models/internvla_a1/adapter_qwen3_vl.py | 93 ++------ 9 files changed, 140 insertions(+), 474 deletions(-) delete mode 100644 docs/user_guide/examples/online_serving/gr00t.md diff --git a/docs/.nav.yml b/docs/.nav.yml index dc28cb1e49a..562a49e84f8 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -37,7 +37,6 @@ nav: - vLLM-Omni Helm Chart: 
user_guide/examples/online_serving/chart-helm.md - Diffusers Backend Adapter: user_guide/examples/online_serving/diffusers_pipeline_adapter.md - GLM-Image Online Serving: user_guide/examples/online_serving/glm_image.md - - GR00T OpenPI Serving: user_guide/examples/online_serving/gr00t.md - Image-To-Image: user_guide/examples/online_serving/image_to_image.md - Image-To-Video: user_guide/examples/online_serving/image_to_video.md - Online serving Example of vLLM-Omni for MiMo-Audio: user_guide/examples/online_serving/mimo_audio.md diff --git a/docs/user_guide/examples/online_serving/gr00t.md b/docs/user_guide/examples/online_serving/gr00t.md deleted file mode 100644 index 13e2183dc58..00000000000 --- a/docs/user_guide/examples/online_serving/gr00t.md +++ /dev/null @@ -1,23 +0,0 @@ -# GR00T OpenPI Serving - -GR00T N1.7 is served through `/v1/realtime/robot/openpi`. The endpoint uses the OpenPI msgpack-numpy websocket protocol and returns GR00T actions as `dict[str, np.ndarray]`. - -## Prerequisites - -Install `openpi-client` in the serving environment. The OpenPI endpoint uses `openpi_client.msgpack_numpy` to pack and unpack websocket payloads. - -## Start the server - -```bash -uv run --no-sync --with openpi-client vllm serve nvidia/GR00T-N1.7-3B \ - --omni \ - --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml -``` - -The deploy config is `vllm_omni/deploy/Gr00tN1d7.yaml`. It registers `Gr00tN1d7Pipeline` and exposes `policy_server_config` for the OpenPI handshake. - -## Action output - -Unlike single-stream policies that return one ndarray, GR00T returns a per-action-key dictionary. vLLM-Omni preserves that dictionary under `multimodal_output["actions"]`, and the OpenPI endpoint sends it as the websocket success payload. - -See [`recipes/NVIDIA/GR00T-N1.7.md`](../../../../recipes/NVIDIA/GR00T-N1.7.md) for a full serving recipe with hardware requirements and verification steps. 
diff --git a/recipes/NVIDIA/GR00T-N1.7.md b/recipes/NVIDIA/GR00T-N1.7.md index 61615b18da6..0a314954026 100644 --- a/recipes/NVIDIA/GR00T-N1.7.md +++ b/recipes/NVIDIA/GR00T-N1.7.md @@ -24,27 +24,18 @@ keys (`eef_9d`, `gripper_position`, `joint_position`) with action horizon 40. - OpenPI client library: - Pipeline: `vllm_omni.diffusion.models.gr00t.pipeline_gr00t.Gr00tN1d7Pipeline` - Deploy config: [`vllm_omni/deploy/Gr00tN1d7.yaml`](../../vllm_omni/deploy/Gr00tN1d7.yaml) -- Docs: [`docs/user_guide/examples/online_serving/gr00t.md`](../../docs/user_guide/examples/online_serving/gr00t.md) - E2E test: [`tests/e2e/online_serving/test_gr00t_openpi.py`](../../tests/e2e/online_serving/test_gr00t_openpi.py) -## Hardware Support - -This recipe documents one CUDA GPU serving configuration. GR00T-N1.7-3B fits -on a single 48 GB GPU; smaller GPUs have not been validated. - -## GPU - -### 1x RTX 6000 Ada 48 GB (tested) - -#### Environment +## Environment - OS: Linux - Python: 3.11+ - Driver / runtime: NVIDIA CUDA -- vLLM-Omni version or commit: Use the commit you are deploying from -- Extra Python deps: `pip install openpi-client websockets` +- vLLM-Omni version or commit: use versions from your current checkout + +## Start server -#### Command +From repository root: ```bash vllm serve nvidia/GR00T-N1.7-3B \ @@ -52,15 +43,19 @@ vllm serve nvidia/GR00T-N1.7-3B \ --host 127.0.0.1 \ --port 8000 \ --served-model-name gr00t-n1d7 \ - --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml \ - --disable-log-stats + --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml ``` -The WebSocket endpoint is `ws://127.0.0.1:8000/v1/realtime/robot/openpi`. The -server handshake message (first frame after connect) is a msgpack-encoded dict -with `action_horizon`, `action_keys`, `embodiment_tag`, and `needs_session_id`. +Notes: -#### Verification +- Only `max_num_seqs: 1` is supported (configured in the deploy YAML); GR00T + policy state is per-session and not designed for concurrent batching. 
+- The WebSocket endpoint is `ws://127.0.0.1:8000/v1/realtime/robot/openpi`. + The server handshake message (first frame after connect) is a msgpack-encoded + dict with `action_horizon`, `action_keys`, `embodiment_tag`, and + `needs_session_id`. + +## Verification ```python from tests.gr00t.openpi_client_helper import run_policy_session, validate_session_result @@ -80,10 +75,8 @@ The test sends a synthetic two-frame DROID observation and checks: - All action values are finite float32 - Reset response is `"reset successful"` -#### Notes +## Notes -- Only `max_num_seqs: 1` is supported (configured in the deploy YAML); GR00T - policy state is per-session and not designed for concurrent batching. - To switch embodiment, edit `embodiment_tag` under both `model_config` and `policy_server_config` in `vllm_omni/deploy/Gr00tN1d7.yaml`. Supported values: `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` (default), `XDOF`, `XDOF_SUBTASK`, diff --git a/tests/e2e/online_serving/test_gr00t_openpi.py b/tests/e2e/online_serving/test_gr00t_openpi.py index 0d80d0fbf07..51f1a386cca 100644 --- a/tests/e2e/online_serving/test_gr00t_openpi.py +++ b/tests/e2e/online_serving/test_gr00t_openpi.py @@ -128,11 +128,7 @@ def test_gr00t_n1d7_openpi_online(omni_server) -> None: @hardware_test(res={"cuda": "H100"}, num_cards=1) @pytest.mark.parametrize("omni_server", test_params, indirect=True) def test_gr00t_n1d7_openpi_precision(omni_server) -> None: - """Assert actions match Isaac-GR00T reference (GR00T_NOISE_SEED=42, zero inputs). - - atol=1e-2 covers the ~0.006 max diff from flash-attn 2.7.4 (Isaac-GR00T) vs - vllm.vllm_flash_attn (vLLM-Omni), compounding over 50 denoising steps. 
- """ + """Assert actions match Isaac-GR00T reference (GR00T_NOISE_SEED=42, zero inputs).""" client = openpi_client.OpenPIWebsocketClient(host=omni_server.host, port=omni_server.port) try: obs = openpi_client.build_droid_observation(session_id="gr00t-precision-e2e") diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py index 78411b88ec9..5fd683a26e0 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py +++ b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py @@ -21,225 +21,4 @@ ModalityConfig, ) -MODALITY_CONFIGS = { - ##### Pre-registered pretrain configurations ##### - "oxe_droid_relative_eef_relative_joint": { - "video": ModalityConfig( - delta_indices=[-15, 0], - modality_keys=["exterior_image_1_left", "wrist_image_left"], - ), - "state": ModalityConfig( - delta_indices=[0], - modality_keys=["eef_9d", "gripper_position", "joint_position"], - ), - "action": ModalityConfig( - delta_indices=list(range(40)), - modality_keys=["eef_9d", "gripper_position", "joint_position"], - action_configs=[ - ActionConfig( - rep=ActionRepresentation.RELATIVE, - type=ActionType.EEF, - format=ActionFormat.XYZ_ROT6D, - state_key="eef_9d", - ), - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - state_key="gripper_position", - ), - ActionConfig( - rep=ActionRepresentation.RELATIVE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - state_key="joint_position", - ), - ], - ), - "language": ModalityConfig( - delta_indices=[0], - modality_keys=["annotation.language.language_instruction"], - ), - }, - ##### Pre-registered posttrain configurations ##### - "unitree_g1_sonic": { - "video": ModalityConfig( - delta_indices=[0], - modality_keys=["ego_view"], - ), - "state": ModalityConfig( - delta_indices=[0], - modality_keys=[ - "left_leg", - "right_leg", - "waist", - "left_arm", - "right_arm", - 
"left_hand", - "right_hand", - "projected_gravity", - ], - ), - "action": ModalityConfig( - delta_indices=list(range(40)), - modality_keys=[ - "motion_token", - "left_hand_joints", - "right_hand_joints", - ], - action_configs=[ - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - ], - ), - "language": ModalityConfig( - delta_indices=[0], - modality_keys=["annotation.human.task_description"], - ), - }, - "unitree_g1_full_body_with_waist_height_nav_cmd": { - "video": ModalityConfig( - delta_indices=[0], - modality_keys=["ego_view"], - ), - "state": ModalityConfig( - delta_indices=[0], - modality_keys=[ - "left_leg", - "right_leg", - "waist", - "left_arm", - "right_arm", - "left_hand", - "right_hand", - ], - ), - "action": ModalityConfig( - delta_indices=list(range(50)), - modality_keys=[ - "left_arm", - "right_arm", - "left_hand", - "right_hand", - "waist", - "base_height_command", - "navigate_command", - ], - action_configs=[ - # left_arm - ActionConfig( - rep=ActionRepresentation.RELATIVE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - # right_arm - ActionConfig( - rep=ActionRepresentation.RELATIVE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - # left_hand - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, # G1 hand is controlled by binary signals like a gripper - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - # right_hand - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, # G1 hand is controlled by binary signals like a gripper - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - # waist - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - # 
base_height_command - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - # navigate_command - ActionConfig( - rep=ActionRepresentation.ABSOLUTE, - type=ActionType.NON_EEF, - format=ActionFormat.DEFAULT, - ), - ], - ), - "language": ModalityConfig( - delta_indices=[0], - modality_keys=["annotation.human.task_description"], - ), - }, - "libero_sim": { - "video": ModalityConfig( - delta_indices=[0], - modality_keys=["image", "wrist_image"], - ), - "state": ModalityConfig( - delta_indices=[0], - modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"], - ), - "action": ModalityConfig( - delta_indices=list(range(16)), - modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"], - ), - "language": ModalityConfig( - delta_indices=[0], - modality_keys=["annotation.human.action.task_description"], - ), - }, - "simpler_env_widowx": { - "video": ModalityConfig( - delta_indices=[0], - modality_keys=["image_0"], - ), - "state": ModalityConfig( - delta_indices=[0], - modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"], - ), - "action": ModalityConfig( - delta_indices=list(range(8)), - modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"], - ), - "language": ModalityConfig( - delta_indices=[0], - modality_keys=["annotation.human.action.task_description"], - ), - }, - "simpler_env_google": { - "video": ModalityConfig( - delta_indices=[0], - modality_keys=["image"], - ), - "state": ModalityConfig( - delta_indices=[0], - modality_keys=["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"], - ), - "action": ModalityConfig( - delta_indices=list(range(8)), - modality_keys=["x", "y", "z", "roll", "pitch", "yaw", "gripper"], - ), - "language": ModalityConfig( - delta_indices=[0], - modality_keys=["annotation.human.action.task_description"], - ), - }, -} +__all__ = ["ActionConfig", "ActionFormat", "ActionRepresentation", "ActionType", "ModalityConfig"] diff --git 
a/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py index dc08a0cffe3..9c6995b6c5a 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import MISSING, dataclass - from transformers import PretrainedConfig from . import register_model_config @@ -32,93 +30,92 @@ def _default_diffusion_model_cfg() -> dict: } -@dataclass class Gr00tN1d7Config(PretrainedConfig): - """Unified configuration for Gr00tN1d7 model with backbone and action head. - - Gr00tN1d7 uses the Cosmos-Reason2-2B (Qwen3-VL architecture) VLM backbone, - replacing the Eagle backbone used in Gr00tN1d6. - """ - - # Model identification - model_type: str = "Gr00tN1d7" - model_dtype: str = "bfloat16" # Use bfloat16 for Flash Attention compatibility - - # Backbone configuration - model_name: str = "nvidia/Cosmos-Reason2-2B" - backbone_model_type: str = "qwen" - model_revision: str | None = None - backbone_embedding_dim: int = 2048 # project_to_dim; must match Cosmos-Reason2-2B hidden size - select_layer: int = 12 - reproject_vision: bool = False - use_flash_attention: bool = True - load_bf16: bool = False # Enable BF16 loading - - ### Processing parameters - image_crop_size: tuple[int, int] | None = (230, 230) - image_target_size: tuple[int, int] | None = (256, 256) - - shortest_image_edge: int | None = None - crop_fraction: float | None = None - - random_rotation_angle: int | None = None - color_jitter_params: dict[str, float] | None = None - formalize_language: bool = True - apply_sincos_state_encoding: bool = False # Global flag to enable per-embodiment sin/cos encoding - use_percentiles: bool = True - use_relative_action: bool = False - - # Action head configuration parameters - max_state_dim: int = 132 # Default from state_shape - 
max_action_dim: int = 132 # Default from action_shape - action_horizon: int = 40 - hidden_size: int = 1024 - input_embedding_dim: int = 1536 - - # State history: number of consecutive state timesteps fed to the state encoder - state_history_length: int = 1 - - # Global parameters - add_pos_embed: bool = True - attn_dropout: float = 0.2 - use_vlln: bool = True - max_seq_len: int = 1024 - use_alternate_vl_dit: bool = True # True for AlternateVLDiT, False for DiT - attend_text_every_n_blocks: int = 2 - - diffusion_model_cfg: dict | None = None - vl_self_attention_cfg: dict | None = None - - # Flow matching parameters - num_inference_timesteps: int = 4 - num_timestep_buckets: int = 1000 - - # State augmentation parameters (inference-relevant only) - exclude_state: bool = False # Zero out all state inputs (ablation) - use_mean_std: bool = False # Use mean/std normalization instead of min/max - - # Multi-embodiment parameters - max_num_embodiments: int = 32 - - def __init__(self, **kwargs): + model_type = "Gr00tN1d7" + + def __init__( + self, + model_dtype: str = "bfloat16", + model_name: str = "nvidia/Cosmos-Reason2-2B", + backbone_model_type: str = "qwen", + model_revision: str | None = None, + backbone_embedding_dim: int = 2048, + select_layer: int = 12, + reproject_vision: bool = False, + use_flash_attention: bool = True, + load_bf16: bool = False, + image_crop_size: tuple | None = (230, 230), + image_target_size: tuple | None = (256, 256), + shortest_image_edge: int | None = None, + crop_fraction: float | None = None, + random_rotation_angle: int | None = None, + color_jitter_params: dict | None = None, + formalize_language: bool = True, + apply_sincos_state_encoding: bool = False, + use_percentiles: bool = True, + use_relative_action: bool = False, + max_state_dim: int = 132, + max_action_dim: int = 132, + action_horizon: int = 40, + hidden_size: int = 1024, + input_embedding_dim: int = 1536, + state_history_length: int = 1, + add_pos_embed: bool = True, + 
attn_dropout: float = 0.2, + use_vlln: bool = True, + max_seq_len: int = 1024, + use_alternate_vl_dit: bool = True, + attend_text_every_n_blocks: int = 2, + diffusion_model_cfg: dict | None = None, + vl_self_attention_cfg: dict | None = None, + num_inference_timesteps: int = 4, + num_timestep_buckets: int = 1000, + exclude_state: bool = False, + use_mean_std: bool = False, + max_num_embodiments: int = 32, + **kwargs, + ): super().__init__(**kwargs) - for key, value in kwargs.items(): - setattr(self, key, value) - - # Ensures that all dataclass defaults (including those using default_factory) - # are explicitly assigned to the instance, even if dataclasses initialization or subclassing - # (PretrainedConfig) interferes with normal default injection. - for f in self.__dataclass_fields__.values(): - if not hasattr(self, f.name): - if f.default is not MISSING: - setattr(self, f.name, f.default) - elif getattr(f, "default_factory", MISSING) is not MISSING: - setattr(self, f.name, f.default_factory()) - - if self.diffusion_model_cfg is None: - self.diffusion_model_cfg = _default_diffusion_model_cfg() - else: - self.diffusion_model_cfg = dict(self.diffusion_model_cfg) + self.model_dtype = model_dtype + self.model_name = model_name + self.backbone_model_type = backbone_model_type + self.model_revision = model_revision + self.backbone_embedding_dim = backbone_embedding_dim + self.select_layer = select_layer + self.reproject_vision = reproject_vision + self.use_flash_attention = use_flash_attention + self.load_bf16 = load_bf16 + self.image_crop_size = image_crop_size + self.image_target_size = image_target_size + self.shortest_image_edge = shortest_image_edge + self.crop_fraction = crop_fraction + self.random_rotation_angle = random_rotation_angle + self.color_jitter_params = color_jitter_params + self.formalize_language = formalize_language + self.apply_sincos_state_encoding = apply_sincos_state_encoding + self.use_percentiles = use_percentiles + self.use_relative_action = 
use_relative_action + self.max_state_dim = max_state_dim + self.max_action_dim = max_action_dim + self.action_horizon = action_horizon + self.hidden_size = hidden_size + self.input_embedding_dim = input_embedding_dim + self.state_history_length = state_history_length + self.add_pos_embed = add_pos_embed + self.attn_dropout = attn_dropout + self.use_vlln = use_vlln + self.max_seq_len = max_seq_len + self.use_alternate_vl_dit = use_alternate_vl_dit + self.attend_text_every_n_blocks = attend_text_every_n_blocks + self.diffusion_model_cfg = ( + diffusion_model_cfg if diffusion_model_cfg is not None else _default_diffusion_model_cfg() + ) + self.vl_self_attention_cfg = vl_self_attention_cfg + self.num_inference_timesteps = num_inference_timesteps + self.num_timestep_buckets = num_timestep_buckets + self.exclude_state = exclude_state + self.use_mean_std = use_mean_std + self.max_num_embodiments = max_num_embodiments register_model_config("Gr00tN1d7", Gr00tN1d7Config) diff --git a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py index 5b14c61c7a7..8bb15549067 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py @@ -21,9 +21,9 @@ from torch import nn from transformers import AutoConfig, AutoModel, PreTrainedModel from transformers.feature_extraction_utils import BatchFeature -from vllm.vllm_flash_attn import FA2_AVAILABLE as _FA2_AVAILABLE -from vllm.vllm_flash_attn import FA3_AVAILABLE as _FA3_AVAILABLE -from vllm.vllm_flash_attn import is_fa_version_supported as _is_fa_version_supported +from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration as _Qwen3VLForConditionalGeneration, +) from vllm_omni.diffusion.models.gr00t.configs.gr00t_n1d7 import Gr00tN1d7Config from vllm_omni.diffusion.models.gr00t.modeling.modules.dit import AlternateVLDiT, DiT, SelfAttentionTransformer @@ -31,7 +31,7 @@ 
CategorySpecificMLP, MultiEmbodimentActionEncoder, ) -from vllm_omni.diffusion.models.internvla_a1.adapter_qwen3_vl import Qwen3VLForConditionalGeneration +from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7DataCollator logger = logging.getLogger(__name__) @@ -314,37 +314,17 @@ def __init__( self, model_name: str, select_layer: int, - reproject_vision: bool, - use_flash_attention: bool, backbone_embedding_dim: int, load_bf16: bool, transformers_loading_kwargs: dict[str, Any] | None = None, ): super().__init__() - del reproject_vision - - if use_flash_attention: - if _FA3_AVAILABLE and _is_fa_version_supported(3): - attn_implementation = "flash_attention_3" - elif _FA2_AVAILABLE and _is_fa_version_supported(2): - attn_implementation = "flash_attention_2" - else: - logger.warning("No supported flash attention backend on this device, falling back to sdpa.") - attn_implementation = "sdpa" - else: - attn_implementation = "sdpa" - backbone_config = AutoConfig.from_pretrained( model_name, **(transformers_loading_kwargs or {"trust_remote_code": True}) ) - self.model = Qwen3VLForConditionalGeneration(backbone_config).eval() - # Set attention implementation post-init — avoids transformers' init-time flash_attn check. - # Text attention layers go through adapter_qwen3_vl which calls vllm_flash_attn directly. - self.model.config.text_config._attn_implementation = attn_implementation - self.model.model.language_model.config._attn_implementation = attn_implementation - # Vision model uses unpatched transformers attention — sdpa is the safe fallback - # (flash_attn and flash_attn_interface packages are absent; only fa3_fwd_interface is present). 
- self.model.config.vision_config._attn_implementation = "sdpa" + backbone_config.text_config._attn_implementation = "sdpa" + backbone_config.vision_config._attn_implementation = "sdpa" + self.model = _Qwen3VLForConditionalGeneration(backbone_config).eval() if load_bf16: self.model.to(dtype=torch.bfloat16) @@ -402,10 +382,9 @@ def _pre_norm_hook(_module, args, _out): def get_backbone_cls(config: Gr00tN1d7Config): - if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name: + if "nvidia/Cosmos-Reason2" in config.model_name: return _Qwen3VLBackbone - else: - raise ValueError(f"Unsupported model name: {config.model_name}") + raise ValueError(f"Unsupported model name: {config.model_name}") class Gr00tN1d7(PreTrainedModel): @@ -433,8 +412,6 @@ def __init__( self.backbone = backbone_cls( model_name=config.model_name, select_layer=config.select_layer, - reproject_vision=config.reproject_vision, - use_flash_attention=config.use_flash_attention, backbone_embedding_dim=config.backbone_embedding_dim, load_bf16=config.load_bf16, transformers_loading_kwargs=transformers_loading_kwargs, @@ -442,7 +419,6 @@ def __init__( # Initialize action head self.action_head = Gr00tN1d7ActionHead(config) - from .processing_gr00t_n1d7 import Gr00tN1d7DataCollator self.collator = Gr00tN1d7DataCollator( model_name=config.model_name, diff --git a/vllm_omni/diffusion/models/gr00t/policy.py b/vllm_omni/diffusion/models/gr00t/policy.py index 738988cae00..ae10ad187b4 100644 --- a/vllm_omni/diffusion/models/gr00t/policy.py +++ b/vllm_omni/diffusion/models/gr00t/policy.py @@ -22,6 +22,8 @@ from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import FINETUNE_ONLY_TAGS, POSTTRAIN_TAGS, EmbodimentTag from vllm_omni.diffusion.models.gr00t.dataio.types import MessageType, ModalityConfig, VLAStepData +from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7 +from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor 
def _rec_to_dtype(value: Any, dtype: torch.dtype) -> Any: @@ -65,22 +67,16 @@ def __init__( device: Device to run the model on (e.g., 'cuda:0', 0, 'cpu') strict: Whether to enforce strict input validation (default: True) """ - # Import these local modules to register GR00T with Hugging Face Auto classes. - from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7 # noqa: F401 - from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor # noqa: F401 - self.strict = strict if isinstance(embodiment_tag, str): embodiment_tag = EmbodimentTag.resolve(embodiment_tag) model_dir = Path(model_path) # Load the pretrained model and move to target device with bfloat16 precision - model = AutoModel.from_pretrained(model_dir) - model.eval() # Set model to evaluation mode - model.to(device=device, dtype=torch.bfloat16) - self.model = model + self.model: Gr00tN1d7 = AutoModel.from_pretrained(model_dir) + self.model.eval() + self.model.to(device=device, dtype=torch.bfloat16) - # Load the processor for input/output transformation. # Training saves processor files under a "processor/" subdirectory, but # AutoProcessor expects them at the model root. Fall back to the # subdirectory when the root lacks a processor_config.json. 
@@ -89,7 +85,7 @@ def __init__( if (model_dir / "processor").is_dir() and not (model_dir / "processor_config.json").exists() else model_dir ) - self.processor = AutoProcessor.from_pretrained(processor_dir) + self.processor: Gr00tN1d7Processor = AutoProcessor.from_pretrained(processor_dir) # Store embodiment-specific configurations self.embodiment_tag = embodiment_tag diff --git a/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py b/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py index 065596ced98..a820fb56aec 100644 --- a/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py +++ b/vllm_omni/diffusion/models/internvla_a1/adapter_qwen3_vl.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import torch import torch.nn as nn from transformers.models.qwen3_vl.modeling_qwen3_vl import ( @@ -13,7 +15,9 @@ Qwen3VLVisionModel, Unpack, apply_rotary_pos_emb, + check_model_inputs, create_causal_mask, + deprecate_kwarg, eager_attention_forward, ) from transformers.models.qwen3_vl.modeling_qwen3_vl import ( @@ -34,50 +38,6 @@ from transformers.models.qwen3_vl.modeling_qwen3_vl import ( Qwen3VLTextRMSNorm as HFQwen3VLTextRMSNorm, ) -from vllm.vllm_flash_attn import FA2_AVAILABLE as _FA2_AVAILABLE -from vllm.vllm_flash_attn import FA3_AVAILABLE as _FA3_AVAILABLE -from vllm.vllm_flash_attn import flash_attn_varlen_func as _vllm_fa_varlen -from vllm.vllm_flash_attn import is_fa_version_supported as _is_fa_version_supported - -_VLLM_FA3_OK: bool = _FA3_AVAILABLE and _is_fa_version_supported(3) -_VLLM_FA2_OK: bool = _FA2_AVAILABLE and _is_fa_version_supported(2) -_VLLM_FA_AVAILABLE: bool = _VLLM_FA3_OK or _VLLM_FA2_OK -_VLLM_FA_VERSION: int = 3 if _VLLM_FA3_OK else 2 - -_FLASH_IMPL_NAMES = frozenset({"flash_attention_2", "flash_attention_3", "flash_attention_4"}) - - -def _vllm_flash_attn_forward( - module, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask, # ignored — causal handled by flash attn 
- dropout: float = 0.0, - scaling: float | None = None, - **kwargs, -) -> tuple[torch.Tensor, None]: - batch_size, n_heads, seq_len, head_dim = query_states.shape - n_kv_heads = key_states.shape[1] - # (batch, heads, seq, dim) → (batch*seq, heads, dim) - q = query_states.transpose(1, 2).contiguous().view(batch_size * seq_len, n_heads, head_dim) - k = key_states.transpose(1, 2).contiguous().view(batch_size * seq_len, n_kv_heads, head_dim) - v = value_states.transpose(1, 2).contiguous().view(batch_size * seq_len, n_kv_heads, head_dim) - cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=q.device) - out = _vllm_fa_varlen( - q, - k, - v, - max_seqlen_q=seq_len, - cu_seqlens_q=cu_seqlens, - max_seqlen_k=seq_len, - cu_seqlens_k=cu_seqlens, - softmax_scale=scaling, - causal=True, - fa_version=_VLLM_FA_VERSION, - ) - # (batch*seq, heads, dim) → (batch, seq, heads, dim) to match transformers interface - return out.view(batch_size, seq_len, n_heads, head_dim), None class Qwen3VLTextRMSNorm(HFQwen3VLTextRMSNorm): @@ -95,6 +55,7 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -125,31 +86,20 @@ def forward( key_states = torch.cat([past_key_values[self.layer_idx][0], key_states], dim=2) value_states = torch.cat([past_key_values[self.layer_idx][1], value_states], dim=2) - if _VLLM_FA_AVAILABLE and self.config._attn_implementation in _FLASH_IMPL_NAMES: - attn_output, attn_weights = _vllm_flash_attn_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( 
- self.config._attn_implementation, eager_attention_forward - ) - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) + attention_interface = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -184,6 +134,7 @@ def __init__(self, config: Qwen3VLTextConfig): self.post_init() + @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, @@ -225,8 +176,9 @@ def forward( attention_mask = create_causal_mask( config=self.config, - inputs_embeds=inputs_embeds, + input_embeds=inputs_embeds, attention_mask=attention_mask, + cache_position=cache_position, past_key_values=past_key_values, position_ids=text_position_ids, ) @@ -278,6 +230,7 @@ def __init__(self, config): class Qwen3VLForConditionalGeneration(HFQwen3VLForConditionalGeneration): _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] accepts_loss_kwargs = False config: Qwen3VLConfig From 18fd4befaf5947dc0e97c670ccc222dff0b633d3 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Mon, 1 Jun 2026 23:17:23 +0700 Subject: [PATCH 8/8] trim docstrings Signed-off-by: Zhengyuan Su --- .../gr00t/configs/embodiment_configs.py | 14 +------------ .../models/gr00t/configs/gr00t_n1d7.py | 14 +------------ .../models/gr00t/dataio/embodiment_tags.py | 18 +---------------- .../dataio/state_action/action_chunking.py | 14 +------------ .../models/gr00t/dataio/state_action/pose.py | 14 
+------------ .../state_action/state_action_processor.py | 14 +------------ .../diffusion/models/gr00t/dataio/types.py | 14 +------------ .../diffusion/models/gr00t/dataio/utils.py | 14 +------------ .../models/gr00t/modeling/gr00t_n1d7.py | 20 ++----------------- .../models/gr00t/modeling/modules/dit.py | 14 +------------ .../modules/embodiment_conditioned_mlp.py | 14 +------------ .../modeling/modules/flowmatching_modules.py | 14 +------------ .../gr00t/modeling/processing_gr00t_n1d7.py | 16 +-------------- vllm_omni/diffusion/models/gr00t/policy.py | 14 +------------ 14 files changed, 15 insertions(+), 193 deletions(-) diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py index 5fd683a26e0..7d11a651cf6 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py +++ b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm_omni.diffusion.models.gr00t.dataio.types import ( ActionConfig, diff --git a/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py index 9c6995b6c5a..d7d4527fb93 100644 --- a/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from transformers import PretrainedConfig diff --git a/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py b/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py index 939d52c1c01..3ed0c9b0360 100755 --- a/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum @@ -52,8 +40,6 @@ class EmbodimentTag(Enum): case-insensitively. """ - ##### Pretrain embodiment tags (in base model processor_config.json) ##### - OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT = "oxe_droid_relative_eef_relative_joint" """ The Open-X-Embodiment DROID robot with relative EEF and relative joint position actions. @@ -94,8 +80,6 @@ class EmbodimentTag(Enum): Real-world R1 Pro Sharpa with relative EEF actions (Mecka data, single-cam). """ - ##### Pre-registered posttrain embodiment tags ##### - UNITREE_G1 = "unitree_g1_full_body_with_waist_height_nav_cmd" """ The Unitree G1 robot (sim, full-body with waist height and nav commands). diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py index 28a9f030df6..4128d6cae6c 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence from typing import Generic, TypeVar diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py index 1cc714a3e2b..0e0849f17b4 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum from typing import TypeVar diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py index cce6362e4e5..14e2b0b3a37 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Unified state and action processor for robotics.""" diff --git a/vllm_omni/diffusion/models/gr00t/dataio/types.py b/vllm_omni/diffusion/models/gr00t/dataio/types.py index 09e32e0a7e5..d4948a1c234 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/types.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/types.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field from enum import Enum diff --git a/vllm_omni/diffusion/models/gr00t/dataio/utils.py b/vllm_omni/diffusion/models/gr00t/dataio/utils.py index 379b0d87606..8096d62396c 100644 --- a/vllm_omni/diffusion/models/gr00t/dataio/utils.py +++ b/vllm_omni/diffusion/models/gr00t/dataio/utils.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import asdict, is_dataclass from enum import Enum diff --git a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py index 8bb15549067..4799b7cf615 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging import os @@ -122,18 +110,15 @@ def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFea """ backbone_output = self.process_backbone_output(backbone_output) - # Get vision and language embeddings. 
vl_embeds = backbone_output.backbone_features embodiment_id = action_input.embodiment_id - # Handle state history: if we have fewer timesteps than expected, repeat to fill state = action_input.state current_T = state.shape[1] assert current_T == self.config.state_history_length, "current_T != state_history_length" - # Reshape state from [B, state_history_length, max_state_dim] to [B, 1, state_history_length * max_state_dim] + # [B, state_history_length, max_state_dim] -> [B, 1, state_history_length * max_state_dim] state = state.view(state.shape[0], 1, -1) - # Embed state. state_features = self.state_encoder(state, embodiment_id) return BatchFeature(data={"backbone_features": vl_embeds, "state_features": state_features}) @@ -159,7 +144,6 @@ def get_action_with_features( """ vl_embeds = backbone_features - # Set initial actions as the sampled noise. batch_size = vl_embeds.shape[0] device = vl_embeds.device _seed_env = os.environ.get("GR00T_NOISE_SEED") diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py index db424adab3b..54fbe770d0d 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import torch.nn.functional as F diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py index b740f4eadd0..17f0ed661f7 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import torch.nn.functional as F diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py index 20f2bc06aed..86b6c4b344b 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch from torch import nn diff --git a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py index fb60b9b2afd..d26ed38be86 100644 --- a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py +++ b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json import os @@ -122,7 +110,6 @@ def __init__( transformers_loading_kwargs: dict = {}, ): self.processor = build_processor(model_name, transformers_loading_kwargs) - # Set padding side to 'left' for Flash Attention compatibility self.processor.tokenizer.padding_side = "left" self.model_type = model_type self.model_name = model_name @@ -231,7 +218,6 @@ def __init__( self.random_rotation_angle = random_rotation_angle self.color_jitter_params = color_jitter_params self.processor = build_processor(model_name, transformers_loading_kwargs) - # Set padding side to 'left' for Flash Attention compatibility self.processor.tokenizer.padding_side = "left" self.embodiment_id_mapping = embodiment_id_mapping or EMBODIMENT_TAG_TO_PROJECTOR_INDEX for k, v in EMBODIMENT_TAG_TO_PROJECTOR_INDEX.items(): diff --git a/vllm_omni/diffusion/models/gr00t/policy.py b/vllm_omni/diffusion/models/gr00t/policy.py index ae10ad187b4..b37c2eb46ba 100644 --- a/vllm_omni/diffusion/models/gr00t/policy.py +++ b/vllm_omni/diffusion/models/gr00t/policy.py @@ -1,17 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path from typing import Any