diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index f58123317de..3b1e60b47ae 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -22,6 +22,7 @@ th {
| `MingFlashOmniForConditionalGeneration` + `MingImagePipeline` | Ming-flash-omni-2.0 (omni-speech + imagegen1) | `Jonathan1909/Ming-flash-omni-2.0` | ✅︎ | | | |
| `BagelForConditionalGeneration` | BAGEL (DiT-only) | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ | | ✅︎ |
| `InternVLAA1Pipeline` | InternVLA-A1 | `InternRobotics/InternVLA-A1-3B` | ✅︎ | ✅︎ | | |
+| `Gr00tN1d7Pipeline` | GR00T N1.7 | `nvidia/GR00T-N1.7-3B` | ✅︎ | | | |
| `HunyuanImage3ForCausalMM` | HunyuanImage3.0 (DiT-only) | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `QwenImagePipeline` | Qwen-Image-2512 | `Qwen/Qwen-Image-2512` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
diff --git a/recipes/NVIDIA/GR00T-N1.7.md b/recipes/NVIDIA/GR00T-N1.7.md
new file mode 100644
index 00000000000..0a314954026
--- /dev/null
+++ b/recipes/NVIDIA/GR00T-N1.7.md
@@ -0,0 +1,86 @@
+# GR00T-N1.7
+
+> NVIDIA Isaac GR00T-N1.7-3B robot VLA policy served over the OpenPI WebSocket protocol
+
+## Summary
+
+- Vendor: NVIDIA
+- Model: `nvidia/GR00T-N1.7-3B`
+- Task: Vision-Language-Action (VLA) inference for robot manipulation
+- Mode: Online serving via OpenPI WebSocket endpoint
+- Maintainer: timzsu
+
+## When to use this recipe
+
+Use this recipe when you need to serve GR00T-N1.7 as a real-time robot policy
+over the OpenPI WebSocket API. It configures the DROID embodiment
+(`OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT`) and exposes the standard DROID action
+keys (`eef_9d`, `gripper_position`, `joint_position`) with action horizon 40.
+
+## References
+
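+- Upstream model: [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)
+- Upstream codebase: [NVIDIA Isaac-GR00T](https://github.com/NVIDIA/Isaac-GR00T)
+- OpenPI client library: [`openpi-client`](https://github.com/Physical-Intelligence/openpi) (provides `openpi_client.msgpack_numpy`)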
+- Pipeline: `vllm_omni.diffusion.models.gr00t.pipeline_gr00t.Gr00tN1d7Pipeline`
+- Deploy config: [`vllm_omni/deploy/Gr00tN1d7.yaml`](../../vllm_omni/deploy/Gr00tN1d7.yaml)
+- E2E test: [`tests/e2e/online_serving/test_gr00t_openpi.py`](../../tests/e2e/online_serving/test_gr00t_openpi.py)
+
+## Environment
+
+- OS: Linux
+- Python: 3.11+
+- Driver / runtime: NVIDIA CUDA
+- vLLM-Omni version or commit: use versions from your current checkout
+
+## Start server
+
+From repository root:
+
+```bash
+vllm serve nvidia/GR00T-N1.7-3B \
+ --omni \
+ --host 127.0.0.1 \
+ --port 8000 \
+ --served-model-name gr00t-n1d7 \
+ --stage-configs-path vllm_omni/deploy/Gr00tN1d7.yaml
+```
+
+Notes:
+
+- Only `max_num_seqs: 1` is supported (configured in the deploy YAML); GR00T
+ policy state is per-session and not designed for concurrent batching.
+- The WebSocket endpoint is `ws://127.0.0.1:8000/v1/realtime/robot/openpi`.
+ The server handshake message (first frame after connect) is a msgpack-encoded
+ dict with `action_horizon`, `action_keys`, `embodiment_tag`, and
+ `needs_session_id`.
+
+## Verification
+
+```python
+from tests.gr00t.openpi_client_helper import run_policy_session, validate_session_result
+validate_session_result(run_policy_session(host="127.0.0.1", port=8000))
+```
+
+Or run the e2e test suite:
+
+```bash
+python -m pytest tests/e2e/online_serving/test_gr00t_openpi.py -v
+```
+
+The test sends a synthetic two-frame DROID observation and checks:
+
+- GR00T metadata contract: `action_horizon`, `action_keys`, `embodiment_tag`, `needs_session_id`
+- Action shapes: `eef_9d (1,40,9)`, `gripper_position (1,40,1)`, `joint_position (1,40,7)`
+- All action values are finite float32
+- Reset response is `"reset successful"`
+
+## Notes
+
+- To switch embodiment, edit `embodiment_tag` under both `model_config` and
+ `policy_server_config` in `vllm_omni/deploy/Gr00tN1d7.yaml`. Supported values:
+ `OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT` (default), `XDOF`, `XDOF_SUBTASK`, `REAL_G1`,
+ `REAL_R1_PRO_SHARPA`, `REAL_R1_PRO_SHARPA_HUMAN`, `REAL_R1_PRO_SHARPA_MAXINSIGHTS`,
+ `REAL_R1_PRO_SHARPA_MECKA`, `LIBERO_PANDA`, `SIMPLER_ENV_GOOGLE`, `SIMPLER_ENV_WIDOWX`.
+- GR00T weights are loaded directly by `Gr00tPolicy` via `AutoModel.from_pretrained`;
+ the pipeline's `load_weights` is intentionally a no-op.
diff --git a/tests/diffusion/models/gr00t/__init__.py b/tests/diffusion/models/gr00t/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/diffusion/models/gr00t/test_pipeline.py b/tests/diffusion/models/gr00t/test_pipeline.py
new file mode 100644
index 00000000000..dd75773ec3f
--- /dev/null
+++ b/tests/diffusion/models/gr00t/test_pipeline.py
@@ -0,0 +1,142 @@
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+
+from vllm_omni.diffusion.models.gr00t import pipeline_gr00t
+from vllm_omni.diffusion.models.gr00t.pipeline_gr00t import Gr00tN1d7Pipeline
+from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+class FakeGr00tPolicy:
+ instances = []
+
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.reset_calls = 0
+ self.seen_obs = None
+ self.embodiment_tag = SimpleNamespace(value="fake_embodiment")
+ self.language_key = "annotation.language.language_instruction"
+ self.modality_configs = {
+ "action": SimpleNamespace(
+ delta_indices=[0, 1],
+ modality_keys=["arm", "gripper"],
+ )
+ }
+ self.processor = SimpleNamespace(
+ state_action_processor=SimpleNamespace(
+ norm_params={
+ "fake_embodiment": {
+ "action": {
+ "arm": {"dim": np.array(2)},
+ "gripper": {"dim": np.array(1)},
+ }
+ }
+ }
+ )
+ )
+ FakeGr00tPolicy.instances.append(self)
+
+ def get_action(self, obs):
+ self.seen_obs = obs
+ return {
+ "arm": np.array([[[1.0, 2.0]]], dtype=np.float64),
+ "gripper": [[[3.0]]],
+ }, {"latency_ms": 1.0}
+
+ def reset(self):
+ self.reset_calls += 1
+ return {"reset": True}
+
+
+@pytest.fixture(autouse=True)
+def fake_gr00t_policy(monkeypatch):
+ FakeGr00tPolicy.instances.clear()
+ monkeypatch.setattr(pipeline_gr00t, "Gr00tPolicy", FakeGr00tPolicy)
+
+
+def _pipeline():
+ od_config = SimpleNamespace(
+ model="nvidia/GR00T-N1.7-3B",
+ model_config={
+ "embodiment_tag": "LIBERO_PANDA",
+ "strict": False,
+ },
+ custom_pipeline_args={},
+ )
+ return Gr00tN1d7Pipeline(od_config=od_config)
+
+
+def test_pipeline_initializes_local_policy():
+ pipeline = _pipeline()
+
+ policy = FakeGr00tPolicy.instances[0]
+ assert policy.kwargs["model_path"] == "nvidia/GR00T-N1.7-3B"
+ assert policy.kwargs["embodiment_tag"] == "LIBERO_PANDA"
+ assert policy.kwargs["strict"] is False
+ assert pipeline.weights_sources == ()
+ assert pipeline.load_weights(iter(())) == set()
+
+
+def test_forward_returns_dict_actions_in_multimodal_output():
+ pipeline = _pipeline()
+ req = OmniDiffusionRequest(
+ prompts=["pick"],
+ request_ids=["req"],
+ sampling_params=OmniDiffusionSamplingParams(
+ extra_args={
+ "robot_obs": {
+ "images": {"cam": np.zeros((1, 1, 8, 8, 3), dtype=np.uint8)},
+ "state": {"joint": np.zeros((1, 1, 2), dtype=np.float32)},
+ "prompt": "pick the cube",
+ "session_id": "session-a",
+ },
+ "reset": True,
+ }
+ ),
+ )
+
+ output = pipeline.forward(req)
+
+ assert output.error is None
+ assert output.output is None
+ actions = output.multimodal_output["actions"]
+ assert set(actions) == {"arm", "gripper"}
+ assert actions["arm"].dtype == np.float32
+ np.testing.assert_allclose(actions["arm"], np.array([[[1.0, 2.0]]], dtype=np.float32))
+ policy = FakeGr00tPolicy.instances[0]
+ assert "video" in policy.seen_obs
+ assert policy.seen_obs["language"] == {"annotation.language.language_instruction": [["pick the cube"]]}
+ assert "images" not in policy.seen_obs
+ assert "prompt" not in policy.seen_obs
+ assert "session_id" not in policy.seen_obs
+ assert policy.reset_calls == 1
+
+
+def test_dummy_warmup_returns_shape_correct_zero_actions():
+ pipeline = _pipeline()
+ req = OmniDiffusionRequest(
+ prompts=["dummy run"],
+ request_ids=["dummy_req_id"],
+ sampling_params=OmniDiffusionSamplingParams(num_inference_steps=1),
+ )
+
+ output = pipeline.forward(req)
+
+ assert output.error is None
+ actions = output.multimodal_output["actions"]
+ assert set(actions) == {"arm", "gripper"}
+ assert actions["arm"].shape == (1, 2, 2)
+ assert actions["gripper"].shape == (1, 2, 1)
+ assert not actions["arm"].any()
+ assert FakeGr00tPolicy.instances[0].seen_obs is None
+
+
+def test_reset_delegates_to_policy():
+ pipeline = _pipeline()
+
+ assert pipeline.reset() == {"reset": True}
+ assert FakeGr00tPolicy.instances[0].reset_calls == 1
diff --git a/tests/e2e/online_serving/test_gr00t_openpi.py b/tests/e2e/online_serving/test_gr00t_openpi.py
new file mode 100644
index 00000000000..51f1a386cca
--- /dev/null
+++ b/tests/e2e/online_serving/test_gr00t_openpi.py
@@ -0,0 +1,161 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""End-to-end online serving test for GR00T N1.7 through the OpenPI robot endpoint."""
+
+import os
+
+import numpy as np
+import pytest
+
+from tests.gr00t import openpi_client_helper as openpi_client
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
+from tests.helpers.stage_config import get_deploy_config_path
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+MODEL = "nvidia/GR00T-N1.7-3B"
+openpi_client.require_dependencies()
+
+test_params = [
+ pytest.param(
+ OmniServerParams(
+ model=MODEL,
+ stage_config_path=get_deploy_config_path("Gr00tN1d7.yaml"),
+ server_args=["--disable-log-stats"],
+ env_dict={"VLLM_DISABLE_COMPILE_CACHE": "1", "GR00T_NOISE_SEED": "42"},
+ init_timeout=1200,
+ stage_init_timeout=900,
+ ),
+ id="gr00t-n1d7-openpi",
+ )
+]
+
+# Reference values captured from the Isaac-GR00T ZMQ server (nvidia/GR00T-N1.7-3B,
+# OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT embodiment) using build_droid_observation()
+# inputs: zero images (256×256), identity eef_9d, zero gripper/joint, "pick up the object".
+# Outputs are bit-reproducible across resets and across runs (max_diff=0.0).
+_REF_EEF_9D = np.array(
+ [
+ [
+ 0.014888550154864788,
+ -0.00039258378092199564,
+ -0.013574761338531971,
+ 0.9999837875366211,
+ 0.005678884219378233,
+ -0.00034708858584053814,
+ -0.005677700508385897,
+ 0.9999783635139465,
+ 0.0033210734836757183,
+ ],
+ [
+ 0.020188070833683014,
+ -0.0003105341165792197,
+ -0.02232760563492775,
+ 0.9997346997261047,
+ 0.02301345206797123,
+ 0.000983047066256404,
+ -0.02302766777575016,
+ 0.9995648860931396,
+ 0.018431292846798897,
+ ],
+ [
+ -0.007266733795404434,
+ -0.05537768080830574,
+ 0.03667901083827019,
+ 0.992686927318573,
+ 0.1206541359424591,
+ 0.003906540106981993,
+ -0.12024091929197311,
+ 0.9853788614273071,
+ 0.12070894986391068,
+ ],
+ ],
+ dtype=np.float32,
+) # rows = step 0, step 4, step 39
+
+_REF_GRIPPER = np.array([[0.0], [0.0078125], [0.939453125]], dtype=np.float32) # steps 0, 4, 39
+
+_REF_JOINT = np.array(
+ [
+ [
+ -0.0010484338272362947,
+ 0.0014262489276006818,
+ -0.003565810853615403,
+ -3.846167237497866e-05,
+ -0.0002604846959002316,
+ 0.008521700277924538,
+ -0.006872728932648897,
+ ],
+ [
+ -0.009435676969587803,
+ 0.0021475711837410927,
+ -0.0031688229646533728,
+ -1.8328893929719925e-05,
+ 0.0005945992306806147,
+ 0.019159257411956787,
+ 0.0009468861389905214,
+ ],
+ [
+ -0.02944088727235794,
+ -0.08419207483530045,
+ -0.0251418836414814,
+ 0.00540524534881115,
+ 0.04752273112535477,
+ -0.012884500436484814,
+ 0.024298785254359245,
+ ],
+ ],
+ dtype=np.float32,
+) # rows = step 0, step 4, step 39
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@hardware_test(res={"cuda": "H100"}, num_cards=1)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_gr00t_n1d7_openpi_online(omni_server) -> None:
+ result = openpi_client.run_policy_session(
+ host=omni_server.host,
+ port=omni_server.port,
+ session_id="gr00t-online-e2e",
+ )
+ openpi_client.validate_session_result(result)
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@hardware_test(res={"cuda": "H100"}, num_cards=1)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_gr00t_n1d7_openpi_precision(omni_server) -> None:
+ """Assert actions match Isaac-GR00T reference (GR00T_NOISE_SEED=42, zero inputs)."""
+ client = openpi_client.OpenPIWebsocketClient(host=omni_server.host, port=omni_server.port)
+ try:
+ obs = openpi_client.build_droid_observation(session_id="gr00t-precision-e2e")
+ client.reset({})
+ actions = client.infer(obs)
+ finally:
+ client.close()
+
+ steps = [0, 4, 39]
+ np.testing.assert_allclose(
+ actions["eef_9d"][0, steps, :],
+ _REF_EEF_9D,
+ atol=1e-2,
+ rtol=0.0,
+ err_msg="eef_9d action mismatch vs Isaac-GR00T reference",
+ )
+ np.testing.assert_allclose(
+ actions["gripper_position"][0, steps, :],
+ _REF_GRIPPER,
+ atol=1e-2,
+ rtol=0.0,
+ err_msg="gripper_position action mismatch vs Isaac-GR00T reference",
+ )
+ np.testing.assert_allclose(
+ actions["joint_position"][0, steps, :],
+ _REF_JOINT,
+ atol=1e-2,
+ rtol=0.0,
+ err_msg="joint_position action mismatch vs Isaac-GR00T reference",
+ )
diff --git a/tests/gr00t/__init__.py b/tests/gr00t/__init__.py
new file mode 100644
index 00000000000..208f01a7cb5
--- /dev/null
+++ b/tests/gr00t/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/tests/gr00t/openpi_client_helper.py b/tests/gr00t/openpi_client_helper.py
new file mode 100644
index 00000000000..e5943ca3307
--- /dev/null
+++ b/tests/gr00t/openpi_client_helper.py
@@ -0,0 +1,183 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+
+try:
+ import websockets.sync.client as websockets_client
+except ImportError: # pragma: no cover - optional e2e dependency
+ websockets_client = None
+
+try:
+ from openpi_client import msgpack_numpy
+except ImportError: # pragma: no cover - optional e2e dependency
+ msgpack_numpy = None
+
+PING_INTERVAL_SECS = 300
+PING_TIMEOUT_SECS = 3600
+DEFAULT_HOST = "127.0.0.1"
+DEFAULT_PORT = 8000
+DEFAULT_PATH = "/v1/realtime/robot/openpi"
+DEFAULT_SESSION_ID = "gr00t-smoke"
+ACTION_KEYS = {"eef_9d", "gripper_position", "joint_position"}
+LANGUAGE_KEY = "annotation.language.language_instruction"
+
+
+def _identity_eef_9d_state() -> np.ndarray:
+ state = np.zeros((1, 1, 9), dtype=np.float32)
+ state[..., 3:] = np.array([1, 0, 0, 0, 1, 0], dtype=np.float32)
+ return state
+
+
+def require_dependencies() -> None:
+ missing = []
+ if websockets_client is None:
+ missing.append("websockets")
+ if msgpack_numpy is None:
+ missing.append("openpi-client")
+ if missing:
+ raise ModuleNotFoundError(f"GR00T OpenPI test dependencies are missing: {', '.join(missing)}")
+
+
+@dataclass(frozen=True)
+class Gr00tServerMetadata:
+ action_horizon: int
+ action_keys: set[str]
+ embodiment_tag: str
+ needs_session_id: bool
+
+ @classmethod
+ def from_dict(cls, payload: dict[str, Any]):
+ required_keys = ("action_horizon", "action_keys", "embodiment_tag", "needs_session_id")
+ missing_keys = [key for key in required_keys if key not in payload]
+ if missing_keys:
+ raise ValueError(f"Missing GR00T metadata keys: {missing_keys}")
+
+ return cls(
+ action_horizon=int(payload["action_horizon"]),
+ action_keys={str(key) for key in payload["action_keys"]},
+ embodiment_tag=str(payload["embodiment_tag"]),
+ needs_session_id=bool(payload["needs_session_id"]),
+ )
+
+
+class OpenPIWebsocketClient:
+ def __init__(
+ self,
+ *,
+ host: str = DEFAULT_HOST,
+ port: int = DEFAULT_PORT,
+ path: str = DEFAULT_PATH,
+ ) -> None:
+ require_dependencies()
+ self._uri = f"ws://{host}:{port}{path}"
+ self._packer = msgpack_numpy.Packer()
+ self._ws, self._server_metadata = self._connect()
+
+ def _connect(self):
+ conn = websockets_client.connect(
+ self._uri,
+ compression=None,
+ max_size=None,
+ ping_interval=PING_INTERVAL_SECS,
+ ping_timeout=PING_TIMEOUT_SECS,
+ )
+ metadata = msgpack_numpy.unpackb(conn.recv())
+ if not isinstance(metadata, dict):
+ raise TypeError(f"Expected dict metadata from server, got {type(metadata)!r}")
+ return conn, metadata
+
+ def get_server_metadata(self) -> dict[str, Any]:
+ return dict(self._server_metadata)
+
+ def infer(self, obs: dict[str, Any]) -> dict[str, np.ndarray]:
+ payload = dict(obs)
+ payload["endpoint"] = "infer"
+ self._ws.send(self._packer.pack(payload))
+ response = msgpack_numpy.unpackb(self._ws.recv())
+ if isinstance(response, dict) and response.get("type") == "error":
+ raise RuntimeError(f"Inference failed: {response['message']}")
+ if not isinstance(response, dict):
+ raise RuntimeError(f"Expected dict actions from GR00T OpenPI endpoint, got {type(response)!r}")
+ return {str(key): np.asarray(value, dtype=np.float32) for key, value in response.items()}
+
+ def reset(self, reset_info: dict[str, Any] | None = None) -> str:
+ payload = dict(reset_info or {})
+ payload["endpoint"] = "reset"
+ self._ws.send(self._packer.pack(payload))
+ response = msgpack_numpy.unpackb(self._ws.recv())
+ if not isinstance(response, dict) or response.get("status") != "reset successful":
+ raise RuntimeError(f"Unexpected reset response: {response!r}")
+ return str(response["status"])
+
+ def close(self) -> None:
+ self._ws.close()
+
+
+def build_droid_observation(*, session_id: str = DEFAULT_SESSION_ID) -> dict[str, Any]:
+ return {
+ "session_id": session_id,
+ "video": {
+ "exterior_image_1_left": np.zeros((1, 2, 256, 256, 3), dtype=np.uint8),
+ "wrist_image_left": np.zeros((1, 2, 256, 256, 3), dtype=np.uint8),
+ },
+ "state": {
+ "eef_9d": _identity_eef_9d_state(),
+ "gripper_position": np.zeros((1, 1, 1), dtype=np.float32),
+ "joint_position": np.zeros((1, 1, 7), dtype=np.float32),
+ },
+ "language": {LANGUAGE_KEY: [["pick up the object"]]},
+ }
+
+
+def run_policy_session(
+ *,
+ host: str = DEFAULT_HOST,
+ port: int = DEFAULT_PORT,
+ path: str = DEFAULT_PATH,
+ session_id: str = DEFAULT_SESSION_ID,
+) -> dict[str, Any]:
+ client = OpenPIWebsocketClient(host=host, port=port, path=path)
+ try:
+ metadata = client.get_server_metadata()
+ actions = client.infer(build_droid_observation(session_id=session_id))
+ reset_status = client.reset({})
+ return {
+ "metadata": metadata,
+ "actions": actions,
+ "reset_status": reset_status,
+ "session_id": session_id,
+ }
+ finally:
+ client.close()
+
+
+def validate_session_result(result: dict[str, Any]) -> None:
+ metadata = Gr00tServerMetadata.from_dict(result["metadata"])
+ if metadata.embodiment_tag != "OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT":
+ raise AssertionError(f"Unexpected embodiment_tag: {metadata.embodiment_tag}")
+ if not metadata.needs_session_id:
+ raise AssertionError("GR00T test expects needs_session_id metadata")
+ if metadata.action_keys != ACTION_KEYS:
+ raise AssertionError(f"Unexpected action keys: {metadata.action_keys}")
+ if result["reset_status"] != "reset successful":
+ raise AssertionError(f"Unexpected reset status: {result['reset_status']!r}")
+
+ actions = result["actions"]
+ if set(actions) != ACTION_KEYS:
+ raise AssertionError(f"Unexpected action keys: {set(actions)}")
+ expected_shapes = {
+ "eef_9d": (1, metadata.action_horizon, 9),
+ "gripper_position": (1, metadata.action_horizon, 1),
+ "joint_position": (1, metadata.action_horizon, 7),
+ }
+ for key, expected_shape in expected_shapes.items():
+ if actions[key].shape != expected_shape:
+ raise AssertionError(f"Action {key} shape mismatch: expected {expected_shape}, got {actions[key].shape}")
+ if actions[key].dtype != np.float32:
+ raise AssertionError(f"Action {key} dtype mismatch: expected float32, got {actions[key].dtype}")
+ if not np.isfinite(actions[key]).all():
+ raise AssertionError(f"Action {key} contains non-finite values")
diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py
index 140bddf45ae..c375cc38d60 100644
--- a/vllm_omni/config/pipeline_registry.py
+++ b/vllm_omni/config/pipeline_registry.py
@@ -28,8 +28,6 @@
the entries declared here.
"""
-from __future__ import annotations
-
# --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) ---
_OMNI_PIPELINES: dict[str, tuple[str, str]] = {
# model_type -> (module_path, variable_name)
@@ -69,6 +67,10 @@
"vllm_omni.model_executor.models.glm_image.pipeline",
"GLM_IMAGE_PIPELINE",
),
+ "Gr00tN1d7": (
+ "vllm_omni.model_executor.models.gr00t.pipeline",
+ "GR00T_N1D7_PIPELINE",
+ ),
"hunyuan_image_3_moe": (
"vllm_omni.model_executor.models.hunyuan_image3.pipeline",
"HUNYUAN_IMAGE3_PIPELINE",
diff --git a/vllm_omni/deploy/Gr00tN1d7.yaml b/vllm_omni/deploy/Gr00tN1d7.yaml
new file mode 100644
index 00000000000..8ebb28d5163
--- /dev/null
+++ b/vllm_omni/deploy/Gr00tN1d7.yaml
@@ -0,0 +1,40 @@
+# GR00T N1.7 deploy: single diffusion stage.
+#
+# Topology is declared in vllm_omni/model_executor/models/gr00t/pipeline.py.
+# This default uses one GPU with TP=1 and CFG parallel disabled.
+
+pipeline: Gr00tN1d7
+async_chunk: false
+distributed_executor_backend: mp
+dtype: bfloat16
+
+stages:
+ - stage_id: 0
+ devices: "0"
+ max_num_seqs: 1
+ enforce_eager: true
+ model_class_name: Gr00tN1d7Pipeline
+ model_config:
+ embodiment_tag: OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT
+ strict: true
+ policy_server_config:
+ image_resolution: [180, 320]
+ needs_session_id: true
+ embodiment_tag: OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT
+ action_horizon: 40
+ action_keys:
+ - eef_9d
+ - gripper_position
+ - joint_position
+ supported_embodiments:
+ - OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT
+ - XDOF
+ - XDOF_SUBTASK
+ - REAL_G1
+ - REAL_R1_PRO_SHARPA
+ - REAL_R1_PRO_SHARPA_HUMAN
+ - REAL_R1_PRO_SHARPA_MAXINSIGHTS
+ - REAL_R1_PRO_SHARPA_MECKA
+ - LIBERO_PANDA
+ - SIMPLER_ENV_GOOGLE
+ - SIMPLER_ENV_WIDOWX
diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py
index 5fe2d4061cc..01026f35ebd 100644
--- a/vllm_omni/diffusion/data.py
+++ b/vllm_omni/diffusion/data.py
@@ -935,6 +935,10 @@ def enrich_config(self) -> None:
self.model_class_name = "WanS2VPipeline"
self.tf_model_config = TransformerConfig()
self.update_multimodal_support()
+ elif model_type == "Gr00tN1d7" or "Gr00tN1d7" in architectures:
+ self.model_class_name = "Gr00tN1d7Pipeline"
+ self.set_tf_model_config(TransformerConfig())
+ self.update_multimodal_support()
elif architectures and len(architectures) == 1:
self.model_class_name = architectures[0]
else:
@@ -1006,10 +1010,10 @@ class DiffusionOutput:
"""
# Fields may be replaced with SHM handle dicts by ipc.pack_diffusion_output_shm
- output: torch.Tensor | dict | None = None
- trajectory_timesteps: torch.Tensor | dict | None = None
- trajectory_latents: torch.Tensor | dict | None = None
- trajectory_log_probs: torch.Tensor | dict | None = None
+ output: torch.Tensor | tuple[Any, ...] | dict[str, Any] | None = None
+ trajectory_timesteps: torch.Tensor | dict[str, Any] | None = None
+ trajectory_latents: torch.Tensor | dict[str, Any] | None = None
+ trajectory_log_probs: torch.Tensor | dict[str, Any] | None = None
trajectory_decoded: list[Image.Image] | None = None
error: str | None = None
aborted: bool = False
@@ -1017,6 +1021,11 @@ class DiffusionOutput:
post_process_func: Callable[..., Any] | None = None
+ # Multimodal payloads that should be surfaced on OmniRequestOutput.
+ # Robot policies use this for `{"actions": ...}` without pretending the
+ # action dict is an image/video output.
+ multimodal_output: dict[str, Any] = field(default_factory=dict)
+
# Extra custom output data (e.g. latent trajectories, prompt embeds)
# passed through to OmniRequestOutput.custom_output
custom_output: dict[str, Any] = field(default_factory=dict)
diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py
index 1805b5dc349..d255f8b22d9 100644
--- a/vllm_omni/diffusion/diffusion_engine.py
+++ b/vllm_omni/diffusion/diffusion_engine.py
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from __future__ import annotations
-
import asyncio
import concurrent.futures
import inspect
@@ -220,7 +218,8 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
raise RuntimeError(output.error)
logger.debug("Generation completed successfully.")
- if output.output is None:
+ base_multimodal_output = output.multimodal_output
+ if output.output is None and not base_multimodal_output:
logger.warning("Output is None, returning empty OmniRequestOutput")
return [
OmniRequestOutput.from_diffusion(
@@ -258,12 +257,18 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
custom_output = output.custom_output or {}
model_audio_sample_rate = None
model_fps = None
+ non_visual_output_keys = {"audio", "actions", "custom_output", "audio_sample_rate", "fps"}
if isinstance(outputs, dict):
audio_payload = outputs.get("audio")
+ if "actions" in outputs:
+ base_multimodal_output["actions"] = outputs["actions"]
custom_output.update(outputs.get("custom_output") or {})
model_audio_sample_rate = outputs.get("audio_sample_rate")
model_fps = outputs.get("fps")
- outputs = outputs.get("video", outputs)
+ if "video" in outputs:
+ outputs = outputs["video"]
+ elif outputs.keys() <= non_visual_output_keys:
+ outputs = None
postprocess_time = time.perf_counter() - postprocess_start_time
logger.debug("Post-processing completed in %.4f seconds", postprocess_time)
@@ -346,13 +351,14 @@ def _audio_mm(payload: Any) -> dict[str, Any]:
),
]
else:
- mm_output = {}
+ mm_output = base_multimodal_output.copy()
if audio_payload is not None:
mm_output["audio"] = audio_payload
if model_audio_sample_rate is not None:
mm_output["audio_sample_rate"] = model_audio_sample_rate
if model_fps is not None:
mm_output["fps"] = model_fps
+ final_output_type = "actions" if "actions" in mm_output and not outputs else "image"
return [
OmniRequestOutput.from_diffusion(
request_id=request_id,
@@ -366,6 +372,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]:
trajectory_decoded=output.trajectory_decoded,
custom_output=custom_output,
multimodal_output=mm_output,
+ final_output_type=final_output_type,
stage_durations=output.stage_durations,
peak_memory_mb=output.peak_memory_mb,
),
@@ -405,7 +412,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]:
),
)
else:
- mm_output = {}
+ mm_output = base_multimodal_output.copy()
if audio_payload is not None:
sliced_audio = audio_payload
if isinstance(audio_payload, (list, tuple)):
@@ -422,6 +429,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]:
mm_output["audio_sample_rate"] = model_audio_sample_rate
if model_fps is not None:
mm_output["fps"] = model_fps
+ final_output_type = "actions" if "actions" in mm_output and not request_outputs else "image"
results.append(
OmniRequestOutput.from_diffusion(
request_id=request_id,
@@ -435,6 +443,7 @@ def _audio_mm(payload: Any) -> dict[str, Any]:
trajectory_decoded=output.trajectory_decoded,
custom_output=custom_output,
multimodal_output=mm_output,
+ final_output_type=final_output_type,
stage_durations=output.stage_durations,
peak_memory_mb=output.peak_memory_mb,
),
@@ -569,7 +578,7 @@ def _handle_finished_requests(
def make_engine(
config: OmniDiffusionConfig,
scheduler: SchedulerInterface | None = None,
- ) -> DiffusionEngine:
+ ) -> "DiffusionEngine":
"""Factory method to create a DiffusionEngine instance.
Args:
diff --git a/vllm_omni/diffusion/models/gr00t/__init__.py b/vllm_omni/diffusion/models/gr00t/__init__.py
new file mode 100644
index 00000000000..a10202a40bc
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/__init__.py
@@ -0,0 +1,6 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm_omni.diffusion.models.gr00t.pipeline_gr00t import Gr00tN1d7Pipeline
+
+__all__ = ["Gr00tN1d7Pipeline"]
diff --git a/vllm_omni/diffusion/models/gr00t/configs/__init__.py b/vllm_omni/diffusion/models/gr00t/configs/__init__.py
new file mode 100644
index 00000000000..9648b859751
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/configs/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+MODEL_CONFIG_TYPES: dict[str, type] = {}
+
+
+def register_model_config(shortname: str, configtype: type) -> None:
+ MODEL_CONFIG_TYPES[shortname] = configtype
diff --git a/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py
new file mode 100644
index 00000000000..7d11a651cf6
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/configs/embodiment_configs.py
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm_omni.diffusion.models.gr00t.dataio.types import (
+ ActionConfig,
+ ActionFormat,
+ ActionRepresentation,
+ ActionType,
+ ModalityConfig,
+)
+
+__all__ = ["ActionConfig", "ActionFormat", "ActionRepresentation", "ActionType", "ModalityConfig"]
diff --git a/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py
new file mode 100644
index 00000000000..d7d4527fb93
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/configs/gr00t_n1d7.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from transformers import PretrainedConfig
+
+from . import register_model_config
+
+
+def _default_diffusion_model_cfg() -> dict:
+ return {
+ "positional_embeddings": None,
+ "num_layers": 16,
+ "num_attention_heads": 32,
+ "attention_head_dim": 48,
+ "norm_type": "ada_norm",
+ "output_dim": 1024,
+ "interleave_self_attention": True,
+ }
+
+
+class Gr00tN1d7Config(PretrainedConfig):
+ model_type = "Gr00tN1d7"
+
+ def __init__(
+ self,
+ model_dtype: str = "bfloat16",
+ model_name: str = "nvidia/Cosmos-Reason2-2B",
+ backbone_model_type: str = "qwen",
+ model_revision: str | None = None,
+ backbone_embedding_dim: int = 2048,
+ select_layer: int = 12,
+ reproject_vision: bool = False,
+ use_flash_attention: bool = True,
+ load_bf16: bool = False,
+ image_crop_size: tuple | None = (230, 230),
+ image_target_size: tuple | None = (256, 256),
+ shortest_image_edge: int | None = None,
+ crop_fraction: float | None = None,
+ random_rotation_angle: int | None = None,
+ color_jitter_params: dict | None = None,
+ formalize_language: bool = True,
+ apply_sincos_state_encoding: bool = False,
+ use_percentiles: bool = True,
+ use_relative_action: bool = False,
+ max_state_dim: int = 132,
+ max_action_dim: int = 132,
+ action_horizon: int = 40,
+ hidden_size: int = 1024,
+ input_embedding_dim: int = 1536,
+ state_history_length: int = 1,
+ add_pos_embed: bool = True,
+ attn_dropout: float = 0.2,
+ use_vlln: bool = True,
+ max_seq_len: int = 1024,
+ use_alternate_vl_dit: bool = True,
+ attend_text_every_n_blocks: int = 2,
+ diffusion_model_cfg: dict | None = None,
+ vl_self_attention_cfg: dict | None = None,
+ num_inference_timesteps: int = 4,
+ num_timestep_buckets: int = 1000,
+ exclude_state: bool = False,
+ use_mean_std: bool = False,
+ max_num_embodiments: int = 32,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.model_dtype = model_dtype
+ self.model_name = model_name
+ self.backbone_model_type = backbone_model_type
+ self.model_revision = model_revision
+ self.backbone_embedding_dim = backbone_embedding_dim
+ self.select_layer = select_layer
+ self.reproject_vision = reproject_vision
+ self.use_flash_attention = use_flash_attention
+ self.load_bf16 = load_bf16
+ self.image_crop_size = image_crop_size
+ self.image_target_size = image_target_size
+ self.shortest_image_edge = shortest_image_edge
+ self.crop_fraction = crop_fraction
+ self.random_rotation_angle = random_rotation_angle
+ self.color_jitter_params = color_jitter_params
+ self.formalize_language = formalize_language
+ self.apply_sincos_state_encoding = apply_sincos_state_encoding
+ self.use_percentiles = use_percentiles
+ self.use_relative_action = use_relative_action
+ self.max_state_dim = max_state_dim
+ self.max_action_dim = max_action_dim
+ self.action_horizon = action_horizon
+ self.hidden_size = hidden_size
+ self.input_embedding_dim = input_embedding_dim
+ self.state_history_length = state_history_length
+ self.add_pos_embed = add_pos_embed
+ self.attn_dropout = attn_dropout
+ self.use_vlln = use_vlln
+ self.max_seq_len = max_seq_len
+ self.use_alternate_vl_dit = use_alternate_vl_dit
+ self.attend_text_every_n_blocks = attend_text_every_n_blocks
+ self.diffusion_model_cfg = (
+ diffusion_model_cfg if diffusion_model_cfg is not None else _default_diffusion_model_cfg()
+ )
+ self.vl_self_attention_cfg = vl_self_attention_cfg
+ self.num_inference_timesteps = num_inference_timesteps
+ self.num_timestep_buckets = num_timestep_buckets
+ self.exclude_state = exclude_state
+ self.use_mean_std = use_mean_std
+ self.max_num_embodiments = max_num_embodiments
+
+
+register_model_config("Gr00tN1d7", Gr00tN1d7Config)
diff --git a/vllm_omni/diffusion/models/gr00t/dataio/__init__.py b/vllm_omni/diffusion/models/gr00t/dataio/__init__.py
new file mode 100644
index 00000000000..208f01a7cb5
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/dataio/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py b/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py
new file mode 100755
index 00000000000..3ed0c9b0360
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/dataio/embodiment_tags.py
@@ -0,0 +1,191 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from enum import Enum
+
+"""
+Embodiment tags are used to identify the robot embodiment in the data.
+
+Naming convention:
+<dataset>_<robot_name>
+
+If using multiple datasets, e.g. sim GR1 and real GR1, we can drop the dataset name and use only the robot name.
+"""
+
+
+class EmbodimentTag(Enum):
+ """Embodiment tags supported by the GR00T N1.7 checkpoint.
+
+ Pretrain tags (baked into the base model nvidia/GR00T-N1.7-3B, inference-ready):
+ - OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT -> "oxe_droid_relative_eef_relative_joint"
+ - XDOF -> "xdof_relative_eef_relative_joint"
+ - XDOF_SUBTASK -> "xdof_relative_eef_relative_joint_subtask"
+ - REAL_G1 -> "real_g1_relative_eef_relative_joints"
+ - REAL_R1_PRO_SHARPA -> "real_r1_pro_sharpa_relative_eef"
+ - REAL_R1_PRO_SHARPA_HUMAN -> "real_r1_pro_sharpa_relative_eef_human"
+ - REAL_R1_PRO_SHARPA_MAXINSIGHTS -> "real_r1_pro_sharpa_relative_eef_maxinsights"
+ - REAL_R1_PRO_SHARPA_MECKA -> "real_r1_pro_sharpa_relative_eef_mecka"
+
+ Pre-registered posttrain tags (require finetuned checkpoint):
+ - UNITREE_G1 -> "unitree_g1_full_body_with_waist_height_nav_cmd"
+ - UNITREE_G1_SONIC -> "unitree_g1_sonic"
+ - SIMPLER_ENV_GOOGLE -> "simpler_env_google"
+ - SIMPLER_ENV_WIDOWX -> "simpler_env_widowx"
+ - LIBERO_PANDA -> "libero_sim"
+
+ Finetuning tag (for custom robots):
+ - NEW_EMBODIMENT -> "new_embodiment"
+
+ Use ``EmbodimentTag.resolve(s)`` to look up a tag by name or value,
+ case-insensitively.
+ """
+
+ OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT = "oxe_droid_relative_eef_relative_joint"
+ """
+ The Open-X-Embodiment DROID robot with relative EEF and relative joint position actions.
+ """
+
+ XDOF = "xdof_relative_eef_relative_joint"
+ """
+ The generic X-DOF robot with relative EEF and relative joint position actions.
+ """
+
+ XDOF_SUBTASK = "xdof_relative_eef_relative_joint_subtask"
+ """
+ The generic X-DOF robot (subtask variant).
+ """
+
+ REAL_G1 = "real_g1_relative_eef_relative_joints"
+ """
+ Real-world Unitree G1 with relative EEF and relative joint actions.
+ """
+
+ REAL_R1_PRO_SHARPA = "real_r1_pro_sharpa_relative_eef"
+ """
+ Real-world R1 Pro Sharpa with relative EEF actions.
+ """
+
+ REAL_R1_PRO_SHARPA_HUMAN = "real_r1_pro_sharpa_relative_eef_human"
+ """
+ Real-world R1 Pro Sharpa with relative EEF actions (human teleop data).
+ """
+
+ REAL_R1_PRO_SHARPA_MAXINSIGHTS = "real_r1_pro_sharpa_relative_eef_maxinsights"
+ """
+ Real-world R1 Pro Sharpa with relative EEF actions (MaxInsights data, single-cam).
+ """
+
+ REAL_R1_PRO_SHARPA_MECKA = "real_r1_pro_sharpa_relative_eef_mecka"
+ """
+ Real-world R1 Pro Sharpa with relative EEF actions (Mecka data, single-cam).
+ """
+
+ UNITREE_G1 = "unitree_g1_full_body_with_waist_height_nav_cmd"
+ """
+ The Unitree G1 robot (sim, full-body with waist height and nav commands).
+ """
+
+ UNITREE_G1_SONIC = "unitree_g1_sonic"
+ """
+ The Unitree G1 robot with SONIC whole-body controller. VLA action space is SONIC latents.
+ """
+
+ SIMPLER_ENV_GOOGLE = "simpler_env_google"
+ """
+ The SimplerEnv Google robot.
+ """
+
+ SIMPLER_ENV_WIDOWX = "simpler_env_widowx"
+ """
+ The SimplerEnv WidowX robot.
+ """
+
+ LIBERO_PANDA = "libero_sim"
+ """
+ The LIBERO Panda robot (used for LIBERO-Goal, LIBERO-Object, LIBERO-Spatial, LIBERO-10).
+ """
+
+ # New embodiment during post-training
+ NEW_EMBODIMENT = "new_embodiment"
+ """
+ Any new embodiment.
+ """
+
+    @classmethod
+    def resolve(cls, tag: "str | EmbodimentTag") -> "EmbodimentTag":
+        """Resolve a string to an EmbodimentTag, case-insensitively.
+
+        Matches by enum **name** first (e.g. ``"xdof"`` -> ``XDOF``), then by
+        enum **value** (e.g. ``"xdof_relative_eef_relative_joint"`` -> ``XDOF``).
+
+        Raises:
+            ValueError: If *tag* does not match any known embodiment.
+        """
+        if isinstance(tag, cls):
+            return tag
+        key_lower = tag.strip().lower()
+        # Match by enum name (case-insensitive)
+        for member in cls:
+            if member.name.lower() == key_lower:
+                return member
+        # Match by enum value (case-insensitive)
+        for member in cls:
+            if member.value.lower() == key_lower:
+                return member
+
+        # Walk cls (definition order) instead of the frozenset so the listing is stable across runs.
+        def _fmt(tags):
+            return "\n".join(f" {m.name:40s} -> {m.value}" for m in cls if m in tags)
+
+        msg = (
+            f"Unknown embodiment tag: {tag!r}\n\n"
+            f" Base model tags (work with nvidia/GR00T-N1.7-3B):\n"
+            f"{_fmt(PRETRAIN_TAGS)}\n\n"
+            f" Posttrain tags (require a finetuned checkpoint):\n"
+            f"{_fmt(POSTTRAIN_TAGS)}\n\n"
+            f" Finetuning-only tags (for custom robots):\n"
+            f"{_fmt(FINETUNE_ONLY_TAGS)}"
+        )
+        raise ValueError(msg)
+
+ @classmethod
+ def reverse_lookup(cls, value: str) -> "str":
+ """Map a tag value string back to its enum name, or return the value as-is."""
+ for member in cls:
+ if member.value == value:
+ return member.name
+ return value
+
+
+# Module-level tag category sets (cannot be Enum class attributes).
+PRETRAIN_TAGS: frozenset[EmbodimentTag] = frozenset(
+ {
+ EmbodimentTag.OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT,
+ EmbodimentTag.XDOF,
+ EmbodimentTag.XDOF_SUBTASK,
+ EmbodimentTag.REAL_G1,
+ EmbodimentTag.REAL_R1_PRO_SHARPA,
+ EmbodimentTag.REAL_R1_PRO_SHARPA_HUMAN,
+ EmbodimentTag.REAL_R1_PRO_SHARPA_MAXINSIGHTS,
+ EmbodimentTag.REAL_R1_PRO_SHARPA_MECKA,
+ }
+)
+"""Tags baked into the base model (nvidia/GR00T-N1.7-3B) — usable without finetuning."""
+
+POSTTRAIN_TAGS: frozenset[EmbodimentTag] = frozenset(
+ {
+ EmbodimentTag.UNITREE_G1,
+ EmbodimentTag.UNITREE_G1_SONIC,
+ EmbodimentTag.SIMPLER_ENV_GOOGLE,
+ EmbodimentTag.SIMPLER_ENV_WIDOWX,
+ EmbodimentTag.LIBERO_PANDA,
+ }
+)
+"""Tags that require a finetuned checkpoint."""
+
+FINETUNE_ONLY_TAGS: frozenset[EmbodimentTag] = frozenset(
+ {
+ EmbodimentTag.NEW_EMBODIMENT,
+ }
+)
+"""Tags for custom robots (finetuning only, not in any shipped checkpoint)."""
diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/__init__.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/__init__.py
new file mode 100644
index 00000000000..208f01a7cb5
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py
new file mode 100644
index 00000000000..4128d6cae6c
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/action_chunking.py
@@ -0,0 +1,654 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Generic, TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+from scipy import interpolate
+from scipy.spatial.transform import Rotation, Slerp
+
+from vllm_omni.diffusion.models.gr00t.dataio.state_action.pose import EndEffectorPose, JointPose, Pose
+from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat
+
+PoseType = TypeVar("PoseType", bound=Pose)
+
+
+class ActionChunk(Generic[PoseType]):
+ """
+ Abstract base class for robot action chunking.
+
+ This class provides common functionality for different action chunking types
+ including relative and delta action chunking computation with optional reference frames,
+ interpolation, and format conversion.
+ """
+
+ def __init__(
+ self,
+ poses: Sequence[PoseType],
+ times: Sequence[float] | NDArray[np.float64] | None = None,
+ ):
+ """
+ Initialize action chunking from a list of poses.
+
+ Args:
+ poses: Sequence of Pose objects
+ times: Optional sequence of timestamps for each pose. If None, assumes
+ uniform spacing starting from 0 with step 1.0
+
+ Raises:
+ ValueError: If action chunking is empty or times length doesn't match poses
+ """
+ if not poses:
+ raise ValueError("ActionChunk must contain at least one pose")
+
+ self._poses: list[PoseType] = list(poses)
+
+ # Set up times
+ if times is None:
+ self._times = np.arange(len(poses), dtype=np.float64)
+ else:
+ if len(times) != len(poses):
+ raise ValueError("Number of times must match number of poses")
+ self._times = np.array(times, dtype=np.float64)
+
+ @property
+ def poses(self) -> list[PoseType]:
+ """Get the list of poses"""
+ return self._poses.copy()
+
+ @property
+ def times(self) -> NDArray[np.float64]:
+ """Get the timestamps"""
+ return self._times.copy()
+
+ @property
+ def num_poses(self) -> int:
+ """Get the number of poses in the action chunking"""
+ return len(self._poses)
+
+    def relative_chunking(self, reference_frame: PoseType | None = None) -> "ActionChunk[PoseType]":
+        """
+        Compute the relative action chunking with respect to a reference frame.
+
+        If reference_frame is None, uses the first pose in the action chunking as reference.
+        All poses are transformed to be relative to the reference frame.
+
+        Args:
+            reference_frame: Optional reference pose. If None, uses first pose.
+
+        Returns:
+            A new ActionChunk of the same type where all poses are relative to the reference frame.
+        """
+        # __init__ rejects empty pose sequences, so self._poses holds at least one pose
+        # here and no empty-chunk special case is required.
+
+        # Use the first pose as the reference if one is not provided.
+        ref_pose = reference_frame if reference_frame is not None else self._poses[0]
+
+        # Use the polymorphic subtraction defined in the Pose subclasses.
+        # The subtraction returns the same type as the operands
+        relative_poses: list[PoseType] = [pose - ref_pose for pose in self._poses]  # type: ignore[misc]
+
+        # Return a new instance of the same action chunking class
+        # (e.g., JointActionChunk or EndEffectorActionChunk)
+        return self.__class__(relative_poses, times=self.times)
+
+    def delta_chunking(self, reference_frame: PoseType | None = None) -> "ActionChunk[PoseType]":
+        """
+        Compute the delta action chunking where each pose represents the relative
+        transformation from the previous frame.
+
+        If reference_frame is provided, it is treated as the first frame, and the
+        first delta will be from reference_frame to the first pose in the action chunking.
+        Otherwise, the first pose in the delta action chunking will be the identity/zero transformation.
+
+        Args:
+            reference_frame: Optional reference pose to use as the first frame.
+
+        Returns:
+            A new ActionChunk of the same type where each pose is relative to the previous pose.
+        """
+        # __init__ rejects empty pose sequences, so self._poses holds at least one pose
+        # here and no empty-chunk special case is required.
+
+        delta_poses: list[PoseType] = []
+
+        # Determine the initial reference for the very first pose.
+        # If a reference_frame is given, the first delta is pose[0] - reference_frame.
+        # If not, the first delta is pose[0] - pose[0], resulting in an identity/zero pose.
+        prev_pose = reference_frame if reference_frame is not None else self._poses[0]
+
+        for current_pose in self._poses:
+            delta: PoseType = current_pose - prev_pose  # type: ignore[assignment]
+            delta_poses.append(delta)
+            prev_pose = current_pose  # Update the reference for the next step
+
+        return self.__class__(delta_poses, times=self.times)
+
+ def to_absolute_chunking(self, reference_frame: PoseType) -> "ActionChunk[PoseType]":
+ """
+ Convert a relative action chunking to an absolute action chunking by applying
+ the relative poses on top of a reference frame.
+
+ This is the inverse operation of relative_chunking(). Each relative pose
+ is composed with the reference frame to produce absolute poses.
+
+ Args:
+ reference_frame: The reference pose to apply the relative action chunking on top of.
+
+ Returns:
+ A new ActionChunk of the same type with absolute poses.
+
+ Raises:
+ NotImplementedError: Must be implemented by subclasses
+ """
+ raise NotImplementedError("Subclasses must implement to_absolute_chunking")
+
+ def interpolate(
+ self,
+ num_points: int | None = None,
+ times: NDArray[np.float64] | None = None,
+ ) -> "ActionChunk":
+ """
+ Interpolate the action chunking to generate intermediate poses.
+ Must be implemented by subclasses.
+
+ Args:
+ num_points: Number of evenly-spaced points to generate
+ times: Specific timestamps at which to interpolate
+
+ Returns:
+ A new ActionChunk with interpolated poses
+ """
+ raise NotImplementedError("Subclasses must implement interpolate")
+
+ def to(self, action_format: ActionFormat) -> NDArray[np.float64]:
+ """
+ Convert action chunking to the specified action format.
+ Must be implemented by subclasses.
+
+ Args:
+ action_format: The desired output format
+
+ Returns:
+ Array in the requested format
+
+ Raises:
+ NotImplementedError: If not implemented by subclass
+ """
+ raise NotImplementedError("Subclasses must implement to method")
+
+ def __len__(self) -> int:
+ """Return the number of poses in the action chunking"""
+ return len(self._poses)
+
+ def __getitem__(self, index: int) -> PoseType:
+ """Get a pose by index"""
+ return self._poses[index]
+
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}(num_poses={len(self._poses)}, "
+ f"time_range=[{self._times[0]:.2f}, {self._times[-1]:.2f}])"
+ )
+
+
+class JointActionChunk(ActionChunk[JointPose]):
+ """
+ Represents action chunking in joint space as a sequence of joint configurations.
+
+ Examples:
+ # Create a joint action chunking
+ joint_poses = [
+ JointPose([0.0, 0.0, 0.0]),
+ JointPose([0.5, 0.5, 0.5]),
+ JointPose([1.0, 1.0, 1.0]),
+ ]
+ action_chunking = JointActionChunk(joint_poses)
+
+ # Get relative trajectory (all poses relative to first pose)
+ relative_traj = action_chunking.relative_chunking()
+
+ # Get relative trajectory with custom reference
+ reference = JointPose([0.1, 0.1, 0.1])
+ relative_traj = action_chunking.relative_chunking(reference_frame=reference)
+
+ # Get delta trajectory (incremental changes)
+ delta_traj = action_chunking.delta_chunking()
+
+ # Convert relative trajectory back to absolute
+ reference = JointPose([0.1, 0.1, 0.1])
+ absolute_traj = relative_traj.to_absolute_chunking(reference_frame=reference)
+
+ # Interpolate trajectory
+ interpolated = action_chunking.interpolate(num_points=10)
+
+ # Convert to desired format
+ from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat
+ array_data = action_chunking.to(ActionFormat.DEFAULT) # Returns joint array
+ """
+
+ def __init__(
+ self,
+ poses: Sequence[JointPose],
+ times: Sequence[float] | NDArray[np.float64] | None = None,
+ ):
+ """
+ Initialize a joint trajectory from a list of joint poses.
+
+ Args:
+ poses: Sequence of JointPose objects
+ times: Optional sequence of timestamps for each pose
+
+ Raises:
+ TypeError: If poses are not all JointPose objects
+ """
+ # Validate all poses are JointPose
+ if not all(isinstance(p, JointPose) for p in poses):
+ raise TypeError("All poses must be JointPose objects for JointActionChunk")
+
+ super().__init__(poses, times)
+
+    def interpolate(
+        self,
+        num_points: int | None = None,
+        times: NDArray[np.float64] | None = None,
+    ) -> "JointActionChunk":
+        """
+        Interpolate the joint action chunking to generate intermediate configurations.
+
+        Uses linear interpolation for joint values, after dropping out-of-order samples.
+
+        Args:
+            num_points: Number of evenly-spaced points to generate (including endpoints).
+                Only used if times is None.
+            times: Specific timestamps at which to interpolate. If provided,
+                num_points is ignored.
+
+        Returns:
+            A new JointActionChunk with interpolated poses
+
+        Raises:
+            ValueError: If neither num_points nor times is provided, or if
+                interpolation times are outside the trajectory range
+        """
+        if num_points is None and times is None:
+            raise ValueError("Must provide either num_points or times")
+
+        if len(self._poses) < 2:
+            raise ValueError("Need at least 2 poses for interpolation")
+
+        # Prepare data: extract joint values
+        timestamps = self._times.copy()
+        joint_values = np.array([pose.joints for pose in self._poses])  # (N, num_joints)
+
+        # Drop samples not above the running max so the kept timestamps are strictly increasing
+        running_max = np.maximum.accumulate(timestamps)
+        drop_indices = [idx for idx in range(1, len(timestamps)) if timestamps[idx] <= running_max[idx - 1]]
+        if drop_indices:
+            for idx in drop_indices:
+                print(
+                    f"Dropping timestamp pair - Previous: {timestamps[idx - 1]}, "
+                    f"Current: {timestamps[idx]} at index {idx}"
+                )
+            timestamps = np.delete(timestamps, drop_indices)
+            joint_values = np.delete(joint_values, drop_indices, axis=0)
+
+        # Check if we still have enough poses after cleanup
+        if len(timestamps) < 2:
+            raise ValueError("Need at least 2 poses with monotonic timestamps for interpolation")
+
+        # Create interpolator
+        joint_interp = interpolate.interp1d(timestamps, joint_values, kind="linear", axis=0)
+
+        # Generate interpolation times if not provided
+        if times is None:
+            assert num_points is not None  # Type narrowing for type checker
+            interp_times = np.linspace(timestamps[0], timestamps[-1], num_points)
+        else:
+            interp_times = np.array(times, dtype=np.float64)
+
+        # Check that interpolation times are within bounds
+        if np.any(interp_times < timestamps[0]) or np.any(interp_times > timestamps[-1]):
+            raise ValueError(f"Interpolation times must be within [{timestamps[0]}, {timestamps[-1]}]")
+
+        # Interpolate joint values
+        interp_joint_values = joint_interp(interp_times)
+
+        # Create interpolated poses
+        joint_names = self._poses[0].joint_names
+        interpolated_poses = [
+            JointPose(joints=interp_joint_values[i], joint_names=joint_names) for i in range(len(interp_times))
+        ]
+
+        return JointActionChunk(interpolated_poses, times=interp_times)
+
+ def to_array(self) -> NDArray[np.float64]:
+ """
+ Convert trajectory to array of joint values.
+
+ Returns:
+ Array with shape (N, num_joints) where N is the number of poses
+ """
+ return np.array([pose.joints for pose in self._poses])
+
+    def to_absolute_chunking(self, reference_frame: JointPose) -> "JointActionChunk":
+        """
+        Convert a relative joint action chunking to an absolute action chunking by adding
+        the relative joint positions to the reference frame.
+
+        This is the inverse operation of relative_chunking(). Each relative joint
+        configuration is added to the reference frame to produce absolute joint positions.
+
+        Args:
+            reference_frame: The reference joint pose to apply the relative trajectory on top of.
+
+        Returns:
+            A new JointActionChunk with absolute joint positions.
+
+        Raises:
+            ValueError: If joint dimensions don't match
+        """
+        # __init__ rejects empty pose sequences, so self._poses holds at least one pose
+        # here and no empty-chunk special case is required.
+
+        if len(self._poses[0].joints) != len(reference_frame.joints):
+            raise ValueError(
+                f"Cannot apply relative trajectory: "
+                f"joint dimensions don't match ({len(self._poses[0].joints)} vs {len(reference_frame.joints)})"
+            )
+
+        # Add each relative pose to the reference frame
+        absolute_poses: list[JointPose] = []
+        for relative_pose in self._poses:
+            absolute_joints = reference_frame.joints + relative_pose.joints
+            absolute_pose = JointPose(joints=absolute_joints, joint_names=reference_frame.joint_names)
+            absolute_poses.append(absolute_pose)
+
+        return JointActionChunk(absolute_poses, times=self.times)
+
+ def to(self, action_format: ActionFormat) -> NDArray[np.float64]:
+ """
+ Convert trajectory to the desired format.
+
+ Args:
+ action_format: The desired output format
+
+ Returns:
+ Array in the requested format
+
+ Raises:
+ ValueError: If the action format is not supported for joint trajectories
+ """
+ if action_format == ActionFormat.DEFAULT:
+ return self.to_array()
+ else:
+ raise ValueError(
+ f"ActionFormat {action_format} is not supported for JointActionChunk. "
+ f"Only {ActionFormat.DEFAULT} is supported."
+ )
+
+
+class EndEffectorActionChunk(ActionChunk[EndEffectorPose]):
+ """
+ Represents action chunking in Cartesian space as a sequence of end-effector poses.
+
+ Examples:
+ # Create an end-effector action chunking
+ ee_poses = [
+ EndEffectorPose(translation=[0, 0, 0], rotation=[1, 0, 0, 0],
+ rotation_type="quat", rotation_order="wxyz"),
+ EndEffectorPose(translation=[1, 0, 0], rotation=[0.707, 0, 0, 0.707],
+ rotation_type="quat", rotation_order="wxyz"),
+ EndEffectorPose(translation=[2, 0, 0], rotation=[0, 0, 0, 1],
+ rotation_type="quat", rotation_order="wxyz"),
+ ]
+ action_chunking = EndEffectorActionChunk(ee_poses)
+
+ # Get relative trajectory (all poses relative to first pose)
+ relative_traj = action_chunking.relative_chunking()
+
+ # Get relative trajectory with custom reference frame
+ reference = EndEffectorPose(translation=[0.5, 0, 0], rotation=[1,0,0,0],
+ rotation_type="quat", rotation_order="wxyz")
+ relative_traj = action_chunking.relative_chunking(reference_frame=reference)
+
+ # Get delta trajectory
+ delta_traj = action_chunking.delta_chunking()
+
+ # Convert relative trajectory back to absolute
+ reference = EndEffectorPose(translation=[0.5, 0, 0], rotation=[1,0,0,0],
+ rotation_type="quat", rotation_order="wxyz")
+ absolute_traj = relative_traj.to_absolute_chunking(reference_frame=reference)
+
+ # Interpolate trajectory
+ interpolated = action_chunking.interpolate(num_points=10)
+
+ # Convert to desired format
+ from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat
+ homo_matrices = action_chunking.to(ActionFormat.DEFAULT) # (N, 4, 4) homogeneous matrices
+ xyz_rot6d = action_chunking.to(ActionFormat.XYZ_ROT6D) # (N, 9) xyz + rot6d
+ xyz_rotvec = action_chunking.to(ActionFormat.XYZ_ROTVEC) # (N, 6) xyz + rotvec
+ """
+
+ def __init__(
+ self,
+ poses: Sequence[EndEffectorPose],
+ times: Sequence[float] | NDArray[np.float64] | None = None,
+ ):
+ """
+ Initialize an end-effector trajectory from a list of end-effector poses.
+
+ Args:
+ poses: Sequence of EndEffectorPose objects
+ times: Optional sequence of timestamps for each pose
+
+ Raises:
+ TypeError: If poses are not all EndEffectorPose objects
+ """
+ # Validate all poses are EndEffectorPose
+ if not all(isinstance(p, EndEffectorPose) for p in poses):
+ raise TypeError("All poses must be EndEffectorPose objects for EndEffectorActionChunk")
+
+ super().__init__(poses, times)
+
+ @classmethod
+ def from_array(cls, data: np.ndarray, action_format: ActionFormat) -> "EndEffectorActionChunk":
+ """
+ Create an EndEffectorActionChunk from a 2-D array using the specified action format.
+
+ This is the inverse of ``.to(action_format)``.
+
+ Args:
+ data: Array of shape (N, D) where D depends on the action_format.
+ action_format: The format that describes the layout of each row.
+
+ Returns:
+ EndEffectorActionChunk with N poses.
+ """
+ poses = [EndEffectorPose.from_action_format(row, action_format) for row in data]
+ return cls(poses)
+
+    def interpolate(
+        self,
+        num_points: int | None = None,
+        times: NDArray[np.float64] | None = None,
+    ) -> "EndEffectorActionChunk":
+        """
+        Interpolate the action chunking to generate intermediate poses.
+
+        Uses linear interpolation for translation and SLERP (Spherical Linear
+        Interpolation) for rotation.
+
+        Args:
+            num_points: Number of evenly-spaced points to generate (including endpoints).
+                Only used if times is None.
+            times: Specific timestamps at which to interpolate. If provided,
+                num_points is ignored.
+
+        Returns:
+            A new EndEffectorActionChunk with interpolated poses
+
+        Raises:
+            ValueError: If neither num_points nor times is provided, or if
+                interpolation times are outside the trajectory range
+        """
+        if num_points is None and times is None:
+            raise ValueError("Must provide either num_points or times")
+
+        if len(self._poses) < 2:
+            raise ValueError("Need at least 2 poses for interpolation")
+
+        # Prepare data: extract positions and rotations
+        timestamps = self._times.copy()
+        homogeneous_matrices = np.array([pose.homogeneous for pose in self._poses])
+        positions = homogeneous_matrices[:, :3, 3]
+        rotation_matrices = homogeneous_matrices[:, :3, :3]
+
+        # Drop samples not above the running max; Slerp requires strictly increasing times
+        running_max = np.maximum.accumulate(timestamps)
+        drop_indices = [idx for idx in range(1, len(timestamps)) if timestamps[idx] <= running_max[idx - 1]]
+        # Remove the problematic timestamps and corresponding data
+        if drop_indices:
+            for idx in drop_indices:
+                print(
+                    f"Dropping timestamp pair - Previous: {timestamps[idx - 1]}, "
+                    f"Current: {timestamps[idx]} at index {idx}"
+                )
+            timestamps = np.delete(timestamps, drop_indices)
+            positions = np.delete(positions, drop_indices, axis=0)
+            rotation_matrices = np.delete(rotation_matrices, drop_indices, axis=0)
+
+        # Check if we still have enough poses after cleanup
+        if len(timestamps) < 2:
+            raise ValueError("Need at least 2 poses with monotonic timestamps for interpolation")
+
+        # Create interpolators
+        pos_interp = interpolate.interp1d(timestamps, positions, kind="linear", axis=0)
+        rot_interp = Slerp(timestamps, Rotation.from_matrix(rotation_matrices))
+
+        # Generate interpolation times if not provided
+        if times is None:
+            assert num_points is not None  # Type narrowing for type checker
+            interp_times = np.linspace(timestamps[0], timestamps[-1], num_points)
+        else:
+            interp_times = np.array(times, dtype=np.float64)
+
+        # Check that interpolation times are within bounds
+        if np.any(interp_times < timestamps[0]) or np.any(interp_times > timestamps[-1]):
+            raise ValueError(f"Interpolation times must be within [{timestamps[0]}, {timestamps[-1]}]")
+
+        # Interpolate positions and rotations
+        interp_positions = pos_interp(interp_times)
+        interp_rotations = rot_interp(interp_times)
+
+        # Create interpolated poses
+        interpolated_poses = []
+        for i in range(len(interp_times)):
+            pose = EndEffectorPose(
+                translation=interp_positions[i],
+                rotation=interp_rotations[i].as_matrix(),
+                rotation_type="matrix",
+            )
+            interpolated_poses.append(pose)
+
+        return EndEffectorActionChunk(interpolated_poses, times=interp_times)
+
+ def to_homogeneous_matrices(self) -> NDArray[np.float64]:
+ """
+ Convert trajectory to array of homogeneous transformation matrices.
+
+ Returns:
+ Array of homogeneous matrices with shape (N, 4, 4) where N is the number of poses
+ """
+ return np.array([pose.homogeneous for pose in self._poses])
+
+ def to_translation_rot6d(self) -> NDArray[np.float64]:
+ """
+ Convert trajectory to array of translations and 6D rotations.
+
+ Returns:
+ Array with shape (N, 9) - 3 for xyz + 6 for rot6d
+ """
+ translations = np.array([pose.translation for pose in self._poses]) # (N, 3)
+ rotations = np.array([pose.rot6d for pose in self._poses]) # (N, 6)
+
+ # Concatenate translation and rotation
+ xyz_rot6d = np.concatenate([translations, rotations], axis=1) # (N, 9)
+
+ return xyz_rot6d
+
+ def to_translation_rotvec(self) -> NDArray[np.float64]:
+ """
+ Convert trajectory to array of translations and rotation vectors.
+
+ Returns:
+ Array with shape (N, 6) - 3 for xyz + 3 for rotvec
+ """
+ translations = np.array([pose.translation for pose in self._poses]) # (N, 3)
+ rotations = np.array([pose.rotvec for pose in self._poses]) # (N, 3)
+
+ # Concatenate translation and rotation
+ xyz_rotvec = np.concatenate([translations, rotations], axis=1) # (N, 6)
+
+ return xyz_rotvec
+
+    def to_absolute_chunking(self, reference_frame: EndEffectorPose) -> "EndEffectorActionChunk":
+        """
+        Convert a relative end-effector action chunking to an absolute action chunking by
+        composing each relative transformation with the reference frame.
+
+        This is the inverse operation of relative_chunking(). Each relative pose
+        represents a transformation that is applied on top of the reference frame
+        to produce absolute poses.
+
+        Args:
+            reference_frame: The reference end-effector pose to apply the relative trajectory on top of.
+
+        Returns:
+            A new EndEffectorActionChunk with absolute poses.
+        """
+        # __init__ rejects empty pose sequences, so self._poses holds at least one pose
+        # here and no empty-chunk special case is required.
+
+        # Get reference frame as homogeneous matrix
+        T_ref = reference_frame.homogeneous
+
+        # Compose each relative transformation with the reference frame
+        absolute_poses: list[EndEffectorPose] = []
+        for relative_pose in self._poses:
+            # Get relative transformation as homogeneous matrix
+            T_relative = relative_pose.homogeneous
+
+            # Compose transformations: T_absolute = T_ref @ T_relative
+            T_absolute = T_ref @ T_relative
+
+            # Create absolute pose from composed transformation
+            absolute_pose = EndEffectorPose(homogeneous=T_absolute)
+            absolute_poses.append(absolute_pose)
+
+        return EndEffectorActionChunk(absolute_poses, times=self.times)
+
+ def to(self, action_format: ActionFormat) -> NDArray[np.float64]:
+ """
+ Convert trajectory to the desired format.
+
+ Args:
+ action_format: The desired output format
+
+ Returns:
+ Array in the requested format
+
+ Raises:
+ ValueError: If the action format is not supported
+ """
+ if action_format == ActionFormat.DEFAULT:
+ return self.to_homogeneous_matrices()
+ elif action_format == ActionFormat.XYZ_ROT6D:
+ return self.to_translation_rot6d()
+ elif action_format == ActionFormat.XYZ_ROTVEC:
+ return self.to_translation_rotvec()
+ else:
+ raise ValueError(f"Unsupported action format: {action_format}")
diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py
new file mode 100644
index 00000000000..0e0849f17b4
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/pose.py
@@ -0,0 +1,709 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from enum import Enum
+from typing import TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+from scipy.spatial.transform import Rotation
+
+from vllm_omni.diffusion.models.gr00t.dataio.types import ActionFormat
+
+# TypeVar for self-type preservation in Pose operations
+PoseT = TypeVar("PoseT", bound="Pose")
+
+
def invert_transformation(transform: NDArray[np.float64]) -> NDArray[np.float64]:
    """
    Return the inverse of a rigid-body homogeneous transformation.

    Uses the closed form ``[[R^T, -R^T t], [0, 1]]`` instead of a general
    matrix inverse. The upper-left 3x3 block is therefore treated as an
    orthogonal rotation matrix; this is not checked.

    Args:
        transform: A 4x4 homogeneous transformation matrix

    Returns:
        A new 4x4 matrix holding the inverse transformation
    """
    # For an orthogonal rotation block the transpose is the inverse.
    rotation_transposed = transform[:3, :3].T
    translation = transform[:3, 3]

    inverse = np.eye(4)
    inverse[:3, :3] = rotation_transposed
    # Translation of the inverse: -R^T t
    inverse[:3, 3] = -rotation_transposed @ translation
    return inverse
+
+
def relative_transformation(
    base_transform: NDArray[np.float64], target_transform: NDArray[np.float64]
) -> NDArray[np.float64]:
    """
    Express ``target_transform`` relative to ``base_transform``.

    Computes ``base_transform^{-1} @ target_transform``, so that composing the
    result on the right of ``base_transform`` yields ``target_transform``.

    Args:
        base_transform: Initial 4x4 homogeneous transformation matrix
        target_transform: Current 4x4 homogeneous transformation matrix

    Returns:
        The 4x4 relative transformation from base_transform to target_transform
    """
    base_inverse = invert_transformation(base_transform)
    return base_inverse @ target_transform
+
+
class RotationType(Enum):
    """Rotation representations, keyed by the lowercase string used to select them."""

    QUAT = "quat"  # quaternion
    EULER = "euler"  # Euler angles
    ROTVEC = "rotvec"  # rotation vector (axis-angle)
    MATRIX = "matrix"  # 3x3 rotation matrix
    ROT6D = "rot6d"  # 6D representation
+
+
class EulerOrder(Enum):
    """Euler-angle axis sequences; each value is the lowercase axis string."""

    XYZ = "xyz"
    ZYX = "zyx"
    XZY = "xzy"
    YXZ = "yxz"
    YZX = "yzx"
    ZXY = "zxy"
+
+
class QuatOrder(Enum):
    """Component orderings for quaternions."""

    # Scalar part first: (w, x, y, z)
    WXYZ = "wxyz"
    # Scalar part last: (x, y, z, w)
    XYZW = "xyzw"
+
+
class Pose:
    """
    Base class for robot poses.

    Provides the subtraction operator, which checks that both operands are of
    exactly the same class and then defers to ``_compute_relative``. Subclasses
    supply ``_compute_relative`` and ``copy``; the versions defined here only
    raise ``NotImplementedError``.
    """

    pose_type: str

    def __sub__(self: "PoseT", other: "PoseT") -> "PoseT":
        """
        Compute the pose of ``self`` relative to ``other``.

        For EndEffectorPose: the relative transformation from other to self,
        i.e. the transformation needed to go from other's frame to self's frame.

        For JointPose: the joint-space difference (self - other).

        Args:
            other: The reference pose to compute the relative pose from

        Returns:
            Relative pose (same type as self)

        Raises:
            TypeError: If the two poses are not of exactly the same class

        Examples:
            # End-effector poses
            pose1 = EndEffectorPose(translation=[1, 0, 0], rotation=[1,0,0,0],
                                    rotation_type="quat", rotation_order="wxyz")
            pose2 = EndEffectorPose(translation=[2, 0, 0], rotation=[1,0,0,0],
                                    rotation_type="quat", rotation_order="wxyz")
            relative = pose2 - pose1  # Transformation from pose1 to pose2

            # Joint poses
            joint1 = JointPose([0.0, 0.5, 1.0])
            joint2 = JointPose([0.1, 0.6, 1.2])
            joint_diff = joint2 - joint1  # Joint differences: [0.1, 0.1, 0.2]
        """
        # Exact class match is required; a subclass instance is not accepted.
        if type(self) is type(other):
            return self._compute_relative(other)

        raise TypeError(
            f"Cannot compute relative transformation between different pose types: "
            f"{type(self).__name__} and {type(other).__name__}"
        )

    def _compute_relative(self: "PoseT", other: "PoseT") -> "PoseT":
        """
        Hook used by ``__sub__`` to compute the relative pose.
        Must be implemented by subclasses.

        Args:
            other: The reference pose

        Returns:
            Relative pose
        """
        raise NotImplementedError("Subclasses must implement _compute_relative")

    def copy(self: "PoseT") -> "PoseT":
        """
        Create a deep copy of this pose.
        Must be implemented by subclasses.

        Returns:
            New Pose instance with copied data
        """
        raise NotImplementedError("Subclasses must implement copy")
+
+
class JointPose(Pose):
    """
    Represents a robot configuration in joint space.

    This class stores joint angles/positions for a robot manipulator.
    Unlike end-effector poses, joint poses represent the configuration
    of all joints in the kinematic chain.

    Examples:
        # Create a 6-DOF joint configuration
        joint_pose = JointPose(
            joints=[0.0, -np.pi/4, np.pi/2, 0.0, np.pi/4, 0.0],
            joint_names=["shoulder_pan", "shoulder_lift", "elbow",
                        "wrist_1", "wrist_2", "wrist_3"]
        )

        # Create with default joint names
        joint_pose = JointPose(joints=[0.0, 0.5, 1.0])

        # Get as dictionary
        joint_dict = joint_pose.to_dict()  # {"joint_0": 0.0, ...}

        # Access individual joints
        first_joint = joint_pose.joints[0]
        num_joints = joint_pose.num_joints

        # Compute relative joint displacement
        joint1 = JointPose([0.0, 0.5, 1.0])
        joint2 = JointPose([0.1, 0.6, 1.2])
        relative = joint2 - joint1  # [0.1, 0.1, 0.2]
    """

    pose_type = "joint"

    def __init__(
        self,
        joints: list | np.ndarray,
        joint_names: list | None = None,
    ):
        """
        Initialize a joint pose.

        Args:
            joints: Joint angles/positions as array-like of shape (n,)
            joint_names: Optional sequence of names for each joint. If None,
                        defaults to ["joint_0", "joint_1", ...]

        Raises:
            ValueError: If the number of names differs from the number of joints
        """
        super().__init__()
        self.joints = np.array(joints, dtype=np.float64)

        # Set defaults and validate joint_names
        if joint_names is None:
            self.joint_names = [f"joint_{i}" for i in range(len(self.joints))]
        else:
            if len(joint_names) != len(self.joints):
                raise ValueError(
                    f"Number of joint names ({len(joint_names)}) must match number of joints ({len(self.joints)})"
                )
            # Keep a private list: this avoids aliasing the caller's sequence and
            # guarantees copy() / __eq__ always operate on a real list (a tuple
            # argument would otherwise break ``self.joint_names.copy()``).
            self.joint_names = list(joint_names)

    @property
    def num_joints(self) -> int:
        """
        Get the number of joints.

        Returns:
            Number of joints in the configuration
        """
        return len(self.joints)

    def to_dict(self) -> dict:
        """
        Convert joint configuration to dictionary.

        Returns:
            Dictionary mapping joint names to joint values
        """
        return dict(zip(self.joint_names, self.joints))

    def _compute_relative(self, other):  # type: ignore[override]
        """
        Compute relative joint displacement.

        Args:
            other: Reference joint pose

        Returns:
            JointPose representing the joint-space difference (self - other),
            carrying this pose's joint names

        Raises:
            ValueError: If joint dimensions don't match
        """
        if len(self.joints) != len(other.joints):
            raise ValueError(
                f"Cannot compute relative joint pose: "
                f"joint dimensions don't match ({len(self.joints)} vs {len(other.joints)})"
            )

        relative_joints = self.joints - other.joints
        return JointPose(joints=relative_joints, joint_names=self.joint_names)

    def copy(self) -> "JointPose":
        """
        Create a deep copy of this joint pose.

        Returns:
            New JointPose instance with copied data
        """
        return JointPose(
            joints=self.joints.copy(),
            joint_names=list(self.joint_names),
        )

    def __repr__(self) -> str:
        if len(self.joints) <= 6:
            joints_str = np.array2string(self.joints, precision=4, suppress_small=True)
        else:
            # Long configurations are abbreviated to first/last value plus a count.
            joints_str = f"[{self.joints[0]:.4f}, ..., {self.joints[-1]:.4f}] ({len(self.joints)} joints)"

        return f"JointPose(joints={joints_str})"

    def __eq__(self, other) -> bool:
        """
        Compare two joint poses.

        Poses are equal when they have the same shape, numerically close joint
        values (``np.allclose`` tolerances) and identical joint names. Anything
        that is not a JointPose compares unequal.
        """
        if not isinstance(other, JointPose):
            return False
        # np.allclose broadcasts its arguments and raises ValueError for
        # incompatible lengths; differently shaped poses are simply unequal.
        if self.joints.shape != other.joints.shape:
            return False
        return bool(np.allclose(self.joints, other.joints)) and self.joint_names == other.joint_names

    def __getitem__(self, index) -> float | NDArray[np.float64]:
        """Allow indexing: joint_pose[0] returns first joint value"""
        return self.joints[index]

    def __len__(self) -> int:
        """Allow len(): len(joint_pose) returns number of joints"""
        return len(self.joints)
+
+
class EndEffectorPose(Pose):
    """
    Represents a single end-effector pose with translation and rotation components.

    This class handles Cartesian space representations of robot end-effector poses,
    supporting multiple rotation representations (quaternions, Euler angles, rotation
    vectors, rotation matrices, etc.). Internally the rotation is held as a
    ``scipy.spatial.transform.Rotation`` and the translation as a NumPy array.

    Examples:
        # Create with quaternion (wxyz order)
        pose = EndEffectorPose(
            translation=[1.0, 2.0, 3.0],
            rotation=[1.0, 0.0, 0.0, 0.0],
            rotation_type="quat",
            rotation_order="wxyz"
        )

        # Create with Euler angles (degrees by default)
        pose = EndEffectorPose(
            translation=[1, 2, 3],
            rotation=[0, 0, 90],
            rotation_type="euler",
            rotation_order="xyz"
        )

        # Create with Euler angles in radians
        pose = EndEffectorPose(
            translation=[1, 2, 3],
            rotation=[0, 0, np.pi/2],
            rotation_type="euler",
            rotation_order="xyz",
            degrees=False
        )

        # Create from homogeneous matrix
        H = np.eye(4)
        H[:3, 3] = [1, 2, 3]
        pose = EndEffectorPose(homogeneous=H)

        # Convert between representations
        quat_wxyz = pose.to_rotation("quat", "wxyz")
        euler_zyx = pose.to_rotation("euler", "zyx")
        rot6d = pose.to_rotation("rot6d")

        # Compute relative transformation
        pose1 = EndEffectorPose(translation=[1, 0, 0], rotation=[1,0,0,0],
                                rotation_type="quat", rotation_order="wxyz")
        pose2 = EndEffectorPose(translation=[2, 0, 0], rotation=[1,0,0,0],
                                rotation_type="quat", rotation_order="wxyz")
        relative = pose2 - pose1  # Transformation from pose1's frame to pose2's frame
    """

    pose_type = "end_effector"

    def __init__(
        self,
        translation: list | np.ndarray | None = None,
        rotation: list | np.ndarray | None = None,
        rotation_type: str | None = None,
        rotation_order: str | None = None,
        homogeneous: np.ndarray | None = None,
        degrees: bool = True,
    ):
        """
        Initialize an end-effector pose.

        Args:
            translation: Translation vector [x, y, z]; defaults to zeros
            rotation: Rotation in specified format; defaults to identity
            rotation_type: Type of rotation ("quat", "euler", "rotvec", "matrix", "rot6d")
            rotation_order: Order/convention for the rotation type
            homogeneous: Homogeneous transformation matrix (4, 4)
                        If provided, overrides translation and rotation
            degrees: For Euler angles, whether the input is in degrees (default True)

        Raises:
            ValueError: If rotation is given without rotation_type
        """
        super().__init__()

        # Cache for homogeneous matrix
        self._homogeneous_cache: NDArray[np.float64] | None = None
        self._cache_valid = False

        # Handle homogeneous matrix input (takes precedence over everything else)
        if homogeneous is not None:
            self._from_homogeneous(homogeneous)
            return

        # Store translation
        self._translation = np.array(translation) if translation is not None else np.zeros(3)

        # Store rotation as scipy Rotation object internally
        if rotation is not None:
            if rotation_type is None:
                raise ValueError("rotation_type must be specified when rotation is provided")
            self._set_rotation(rotation, rotation_type, rotation_order, degrees)
        else:
            self._rotation = Rotation.identity()

    def _from_homogeneous(self, homogeneous: np.ndarray):
        """Initialize from homogeneous transformation matrix"""
        # np.array copies, so the stored slices never alias the caller's matrix.
        homogeneous = np.array(homogeneous)

        # Extract translation (last column, first 3 rows)
        self._translation = homogeneous[:3, 3]

        # Extract rotation matrix (top-left 3x3)
        rotation_matrix = homogeneous[:3, :3]

        # Create Rotation object from matrix
        self._rotation = Rotation.from_matrix(rotation_matrix)

    @staticmethod
    def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
        """
        Convert 6D rotation representation to rotation matrix.

        Args:
            rot6d: 6D rotation as (6,) array - first two rows of rotation matrix flattened

        Returns:
            Rotation matrix (3, 3)
        """
        rot6d = rot6d.reshape(2, 3)

        # First two rows
        row1 = rot6d[0]
        row2 = rot6d[1]

        # Normalize first row
        row1 = row1 / np.linalg.norm(row1)

        # Gram-Schmidt orthogonalization for second row
        row2 = row2 - np.dot(row1, row2) * row1
        row2 = row2 / np.linalg.norm(row2)

        # Third row is cross product
        row3 = np.cross(row1, row2)

        # Construct rotation matrix
        rotation_matrix = np.vstack([row1, row2, row3])

        return rotation_matrix

    @staticmethod
    def _matrix_to_rot6d(rotation_matrix: np.ndarray) -> np.ndarray:
        """
        Convert rotation matrix to 6D rotation representation.

        Args:
            rotation_matrix: Rotation matrix (3, 3)

        Returns:
            6D rotation - (6,) array (first two rows flattened)
        """
        return rotation_matrix[:2, :].flatten()

    def _set_rotation(
        self,
        rotation: list | np.ndarray,
        rotation_type: str,
        rotation_order: str | None = None,
        degrees: bool = True,
    ):
        """Internal method to set rotation from various representations"""
        rotation = np.array(rotation)
        rot_type = RotationType(rotation_type.lower())

        if rot_type == RotationType.QUAT:
            quat_order = QuatOrder(rotation_order.lower()) if rotation_order else QuatOrder.WXYZ
            if quat_order == QuatOrder.WXYZ:
                # scipy uses xyzw order, so convert
                quat_xyzw = np.array([rotation[1], rotation[2], rotation[3], rotation[0]])
            else:
                quat_xyzw = rotation
            self._rotation = Rotation.from_quat(quat_xyzw)

        elif rot_type == RotationType.EULER:
            euler_order = EulerOrder(rotation_order.lower()) if rotation_order else EulerOrder.XYZ
            self._rotation = Rotation.from_euler(euler_order.value, rotation, degrees=degrees)

        elif rot_type == RotationType.ROTVEC:
            self._rotation = Rotation.from_rotvec(rotation)

        elif rot_type == RotationType.MATRIX:
            self._rotation = Rotation.from_matrix(rotation)

        elif rot_type == RotationType.ROT6D:
            rotation_matrix = self._rot6d_to_matrix(rotation)
            self._rotation = Rotation.from_matrix(rotation_matrix)

        else:
            raise ValueError(f"Unknown rotation type: {rotation_type}")

        # Invalidate cache
        self._cache_valid = False

    @property
    def translation(self) -> np.ndarray:
        """
        Get translation vector.

        Returns:
            Translation array (a copy) - shape (3,)
        """
        return self._translation.copy()

    @property
    def quat_wxyz(self) -> np.ndarray:
        """Get rotation as quaternion in wxyz order (w, x, y, z)"""
        return self.to_rotation("quat", "wxyz")

    @property
    def quat_xyzw(self) -> np.ndarray:
        """Get rotation as quaternion in xyzw order (x, y, z, w)"""
        return self.to_rotation("quat", "xyzw")

    @property
    def euler_xyz(self) -> np.ndarray:
        """Get rotation as Euler angles in xyz order (degrees)"""
        return self.to_rotation("euler", "xyz")

    @property
    def rotvec(self) -> np.ndarray:
        """Get rotation as rotation vector (axis-angle)"""
        return self.to_rotation("rotvec")

    @property
    def rotation_matrix(self) -> np.ndarray:
        """Get rotation as 3x3 rotation matrix"""
        return self.to_rotation("matrix")

    @property
    def rot6d(self) -> np.ndarray:
        """Get rotation as 6D representation (first two rows of rotation matrix)"""
        return self.to_rotation("rot6d")

    @property
    def xyz_rot6d(self) -> np.ndarray:
        """Get pose as concatenated translation and 6D rotation (9,)"""
        return np.concatenate([self._translation, self.rot6d])

    @property
    def xyz_rotvec(self) -> np.ndarray:
        """Get pose as concatenated translation and rotation vector (6,)"""
        return np.concatenate([self._translation, self.rotvec])

    @property
    def homogeneous(self) -> np.ndarray:
        """
        Get homogeneous transformation matrix.

        The matrix is computed lazily and cached; a copy is returned so callers
        cannot modify the cached value.

        Returns:
            Homogeneous matrix - shape (4, 4)
        """
        if not self._cache_valid:
            self._homogeneous_cache = self._compute_homogeneous()
            self._cache_valid = True
        assert self._homogeneous_cache is not None
        return self._homogeneous_cache.copy()

    def _compute_homogeneous(self) -> np.ndarray:
        """Compute homogeneous transformation matrix"""
        H = np.eye(4)
        H[:3, :3] = self._rotation.as_matrix()
        H[:3, 3] = self._translation
        return H

    def to_rotation(
        self,
        rotation_type: str,
        rotation_order: str | None = None,
        degrees: bool = True,
    ) -> np.ndarray:
        """
        Get rotation in specified representation.

        Args:
            rotation_type: Desired type ("quat", "euler", "rotvec", "matrix", "rot6d")
            rotation_order: Order/convention for the rotation type
                (defaults: "wxyz" for quaternions, "xyz" for Euler angles)
            degrees: For Euler angles, return in degrees (default True)

        Returns:
            Rotation in requested format
            - Shape (4,) for quat
            - Shape (3,) for euler/rotvec
            - Shape (6,) for rot6d
            - Shape (3, 3) for matrix

        Raises:
            ValueError: If rotation_type is not a recognised representation
        """
        rot_type = RotationType(rotation_type.lower())

        if rot_type == RotationType.ROT6D:
            rotation_matrix = self._rotation.as_matrix()
            return self._matrix_to_rot6d(rotation_matrix)

        if rot_type == RotationType.QUAT:
            quat_order = QuatOrder(rotation_order.lower()) if rotation_order else QuatOrder.WXYZ
            quat_xyzw = self._rotation.as_quat()
            if quat_order == QuatOrder.WXYZ:
                return np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]])
            else:
                return quat_xyzw

        elif rot_type == RotationType.EULER:
            euler_order = EulerOrder(rotation_order.lower()) if rotation_order else EulerOrder.XYZ
            return self._rotation.as_euler(euler_order.value, degrees=degrees)

        elif rot_type == RotationType.ROTVEC:
            return self._rotation.as_rotvec()

        elif rot_type == RotationType.MATRIX:
            return self._rotation.as_matrix()

        else:
            raise ValueError(f"Unknown rotation type: {rotation_type}")

    def to_homogeneous(self) -> np.ndarray:
        """
        Convert pose to homogeneous transformation matrix.
        (Alias for the homogeneous property)

        Returns:
            Homogeneous matrix - shape (4, 4)
        """
        return self.homogeneous

    def set_rotation(
        self,
        rotation: list | np.ndarray,
        rotation_type: str,
        rotation_order: str | None = None,
        degrees: bool = True,
    ):
        """
        Set rotation from specified representation.

        Args:
            rotation: Rotation data
            rotation_type: Type of rotation ("quat", "euler", "rotvec", "matrix", "rot6d")
            rotation_order: Order/convention for the rotation type
            degrees: For Euler angles, whether the input is in degrees (default True)
        """
        self._set_rotation(rotation, rotation_type, rotation_order, degrees)

    def _compute_relative(self, other):  # type: ignore[override]
        """
        Compute relative transformation from other to self.

        The result represents the transformation needed to go from other's frame to self's frame.
        Mathematically: T_relative = T_other^{-1} * T_self

        Args:
            other: Reference end-effector pose

        Returns:
            EndEffectorPose representing the relative transformation
        """
        # Get homogeneous matrices
        T_self = self.homogeneous
        T_other = other.homogeneous

        # Compute relative transformation: T_other^{-1} * T_self
        T_relative = relative_transformation(T_other, T_self)

        # Create new EndEffectorPose from relative transformation
        return EndEffectorPose(homogeneous=T_relative)

    @classmethod
    def from_action_format(cls, data: np.ndarray, action_format: ActionFormat) -> "EndEffectorPose":
        """
        Create an EndEffectorPose from a flat array using the specified action format.

        This is the inverse of the xyz_rot6d / xyz_rotvec / homogeneous properties.

        Args:
            data: Flat array whose layout depends on action_format.
            action_format: One of ActionFormat.XYZ_ROT6D, XYZ_ROTVEC, or DEFAULT.

        Returns:
            EndEffectorPose instance.

        Raises:
            ValueError: If action_format is not one of the supported formats.
        """
        if action_format == ActionFormat.XYZ_ROT6D:
            return cls(translation=data[:3], rotation=data[3:], rotation_type="rot6d")
        elif action_format == ActionFormat.XYZ_ROTVEC:
            return cls(translation=data[:3], rotation=data[3:], rotation_type="rotvec")
        elif action_format == ActionFormat.DEFAULT:
            return cls(homogeneous=data.reshape(4, 4))
        else:
            raise ValueError(f"Unsupported ActionFormat: {action_format}")

    def copy(self) -> "EndEffectorPose":
        """
        Create a deep copy of this end-effector pose.

        Returns:
            New EndEffectorPose instance with copied data
        """
        return EndEffectorPose(
            translation=self._translation.copy(),
            rotation=self._rotation.as_quat(),
            rotation_type="quat",
            rotation_order="xyzw",
        )

    def __repr__(self) -> str:
        quat = self.to_rotation("quat", "wxyz")
        return f"EndEffectorPose(translation={self.translation}, rotation_quat_wxyz={quat})"

    def __eq__(self, other) -> bool:
        """
        Compare two end-effector poses.

        Poses are equal when their translations and their rotations are
        numerically close (``np.allclose`` tolerances). Anything that is not an
        EndEffectorPose compares unequal.
        """
        if not isinstance(other, EndEffectorPose):
            return False
        if not np.allclose(self._translation, other._translation):
            return False
        # Compare rotation matrices rather than raw quaternions: q and -q encode
        # the same rotation and as_quat() does not canonicalise the sign by
        # default, so a component-wise quaternion comparison can report two
        # identical orientations as different.
        return bool(np.allclose(self._rotation.as_matrix(), other._rotation.as_matrix()))
diff --git a/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py
new file mode 100644
index 00000000000..14e2b0b3a37
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/dataio/state_action/state_action_processor.py
@@ -0,0 +1,643 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Unified state and action processor for robotics."""
+
+from copy import deepcopy
+
+import numpy as np
+
+from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import (
+ ActionFormat,
+ ActionRepresentation,
+ ActionType,
+ ModalityConfig,
+)
+from vllm_omni.diffusion.models.gr00t.dataio.state_action.action_chunking import (
+ EndEffectorActionChunk,
+ JointActionChunk,
+)
+from vllm_omni.diffusion.models.gr00t.dataio.state_action.pose import EndEffectorPose, JointPose
+from vllm_omni.diffusion.models.gr00t.dataio.utils import (
+ apply_sin_cos_encoding,
+ nested_dict_to_numpy,
+ normalize_values_meanstd,
+ normalize_values_minmax,
+ parse_modality_configs,
+ unnormalize_values_meanstd,
+ unnormalize_values_minmax,
+)
+
+
+class StateActionProcessor:
+ """
+ Unified processor for robot state and action data.
+
+ Handles:
+ - State normalization (min/max, mean/std, sin/cos encoding)
+ - Action normalization
+ - Absolute <-> Relative action representation conversion
+ - Action processing with state dependency
+ """
+
+ def __init__(
+ self,
+ modality_configs: dict[str, dict[str, ModalityConfig]],
+ statistics: (dict[str, dict[str, dict[str, dict[str, list[float]]]]] | None) = None,
+ use_percentiles: bool = False,
+ clip_outliers: bool = True,
+ apply_sincos_state_encoding: bool = False,
+ use_relative_action: bool = False,
+ ):
+ """
+ Initialize unified state and action processor.
+
+ Args:
+ modality_configs: Nested dict with structure:
+ {embodiment_tag: {modality: ModalityConfig}}
+ where modality in ["state", "action"]
+ Example: {"gr1": {"state": ModalityConfig(...), "action": ModalityConfig(...)}}
+ statistics: Optional nested dict with structure:
+ {embodiment_tag: {modality: {joint_group: {stat_type: values}}}}
+ where modality in ["state", "action", "relative_action"]
+ and stat_type in ["min", "max", "mean", "std", "q01", "q99"]
+ Example: {"gr1": {"state": {"left_arm": {"min": [...], "max": [...], ...}}}}
+ use_percentiles: Whether to use percentiles (q01/q99) instead of min/max
+ clip_outliers: Whether to clip normalized values to [-1, 1]
+ apply_sincos_state_encoding: Global flag to enable sin/cos encoding for states
+ """
+ self.modality_configs = parse_modality_configs(modality_configs)
+ self.statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]] = {}
+ self.use_percentiles = use_percentiles
+ self.clip_outliers = clip_outliers
+ self.apply_sincos_state_encoding = apply_sincos_state_encoding
+ self.use_relative_action = use_relative_action
+
+ # Normalization parameters computed from statistics
+ self.norm_params: dict[str, dict[str, dict[str, dict[str, np.ndarray]]]] = {}
+ # Format: norm_params[embodiment_tag][modality][joint_group][stat_type]
+ # where stat_type in ["min", "max", "mean", "std", "dim"]
+
+ if statistics is not None:
+ self.set_statistics(statistics)
+
+ def set_statistics(
+ self,
+ statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]],
+ override: bool = False,
+ ) -> None:
+ """
+ Set dataset statistics for normalization.
+
+ Args:
+ statistics: Nested dict with structure:
+ {embodiment_tag: {modality: {joint_group: {stat_type: values}}}}
+ """
+ for key in statistics:
+ if key not in self.statistics or override:
+ self.statistics[key] = deepcopy(statistics[key])
+ else:
+ print(f"Embodiment tag {key} already in statistics, skipping updating")
+ self._compute_normalization_parameters()
+
+ def _compute_normalization_parameters(self) -> None:
+ """Compute and cache normalization parameters from statistics for all embodiments and modalities."""
+ for embodiment_tag in self.statistics:
+ self.norm_params[embodiment_tag] = {}
+
+ for modality in ["state", "action"]:
+ if modality not in self.statistics[embodiment_tag]:
+ continue
+
+ self.norm_params[embodiment_tag][modality] = {}
+
+ for joint_group, stats in self.statistics[embodiment_tag][modality].items():
+ if self.use_percentiles:
+ min_vals = np.array(stats["q01"])
+ max_vals = np.array(stats["q99"])
+ else:
+ min_vals = np.array(stats["min"])
+ max_vals = np.array(stats["max"])
+
+ mean_vals = np.array(stats["mean"])
+ std_vals = np.array(stats["std"])
+
+ # Compute range, ensuring it's not zero
+ range_vals = max_vals - min_vals
+ range_vals = np.maximum(range_vals, 1e-8)
+
+ self.norm_params[embodiment_tag][modality][joint_group] = {
+ "min": min_vals,
+ "max": max_vals,
+ "dim": np.array(range_vals.shape[0]),
+ "mean": mean_vals,
+ "std": std_vals,
+ }
+
+ # Override absolute action stats with relative stats where specified
+ if "action" in self.modality_configs[embodiment_tag]:
+ modality_keys = self.modality_configs[embodiment_tag]["action"].modality_keys
+ action_configs = self.modality_configs[embodiment_tag]["action"].action_configs
+
+ if action_configs is not None:
+ for key, action_config in zip(modality_keys, action_configs):
+ if action_config.rep == ActionRepresentation.RELATIVE and self.use_relative_action:
+ if "relative_action" not in self.statistics[embodiment_tag]:
+ raise ValueError(
+ f"Relative action statistics required for embodiment '{embodiment_tag}' "
+ f"but 'relative_action' not found in statistics"
+ )
+ if key not in self.statistics[embodiment_tag]["relative_action"]:
+ raise ValueError(
+ f"Relative action statistics required for key '{key}' "
+ f"in embodiment '{embodiment_tag}' but not found"
+ )
+ action_dim = self.norm_params[embodiment_tag]["action"][key]["dim"]
+ self.norm_params[embodiment_tag]["action"][key] = nested_dict_to_numpy(
+ self.statistics[embodiment_tag]["relative_action"][key]
+ )
+ self.norm_params[embodiment_tag]["action"][key]["dim"] = action_dim
+
+ def apply_state(
+ self,
+ state: dict[str, np.ndarray],
+ embodiment_tag: str,
+ ) -> dict[str, np.ndarray]:
+ """
+ Apply state processing (normalization, encoding).
+
+ Args:
+ state: Dict mapping joint_group -> raw state values
+ Shape per group: (..., D) where D is state dimension
+ embodiment_tag: Embodiment identifier (e.g., "gr1")
+
+ Returns:
+ Dict mapping joint_group -> processed state values
+ - Sin/cos encoded groups: (..., 2*D)
+ - Other groups: (..., D)
+ """
+ normalized_values = {}
+ state = deepcopy(state) # Avoid modifying input
+
+ # Get sin/cos embedding keys if enabled
+ sin_cos_keys = None
+ if self.apply_sincos_state_encoding:
+ state_config = self.modality_configs[embodiment_tag].get("state")
+ if state_config and hasattr(state_config, "sin_cos_embedding_keys"):
+ sin_cos_keys = state_config.sin_cos_embedding_keys
+
+ for joint_group in self.modality_configs[embodiment_tag]["state"].modality_keys:
+ if joint_group not in state:
+ raise KeyError(f"Joint group '{joint_group}' not found in state dict for embodiment '{embodiment_tag}'")
+
+ # Strategy 1: Sin/cos encoding (doubles dimension)
+ if sin_cos_keys and joint_group in sin_cos_keys:
+ normalized_values[joint_group] = apply_sin_cos_encoding(state[joint_group])
+
+ # Strategy 2: Mean/std normalization
+ elif (
+ hasattr(
+ self.modality_configs[embodiment_tag]["state"],
+ "mean_std_embedding_keys",
+ )
+ and self.modality_configs[embodiment_tag]["state"].mean_std_embedding_keys
+ and joint_group in self.modality_configs[embodiment_tag]["state"].mean_std_embedding_keys
+ ):
+ params = self.norm_params[embodiment_tag]["state"][joint_group]
+ normalized = normalize_values_meanstd(state[joint_group], params)
+ normalized_values[joint_group] = normalized
+
+ # Strategy 3: Min/max normalization to [-1, 1]
+ else:
+ params = self.norm_params[embodiment_tag]["state"][joint_group]
+ normalized = normalize_values_minmax(state[joint_group], params)
+
+ if self.clip_outliers:
+ normalized = np.clip(normalized, -1.0, 1.0)
+
+ normalized_values[joint_group] = normalized
+
+ return normalized_values
+
+ def unapply_state(
+ self,
+ state: dict[str, np.ndarray],
+ embodiment_tag: str,
+ ) -> dict[str, np.ndarray]:
+ """
+ Reverse state processing (denormalization).
+
+ Args:
+ state: Dict mapping joint_group -> processed state values
+ embodiment_tag: Embodiment identifier
+
+ Returns:
+ Dict mapping joint_group -> raw state values
+
+ Raises:
+ ValueError: If attempting to reverse sin/cos encoding (not reversible)
+ """
+ unnormalized_values = {}
+
+ # Get sin/cos embedding keys if enabled
+ sin_cos_keys = None
+ if self.apply_sincos_state_encoding:
+ state_config = self.modality_configs[embodiment_tag].get("state")
+ if state_config and hasattr(state_config, "sin_cos_embedding_keys"):
+ sin_cos_keys = state_config.sin_cos_embedding_keys
+
+ for joint_group in self.modality_configs[embodiment_tag]["state"].modality_keys:
+ if joint_group not in state:
+ raise KeyError(f"Joint group '{joint_group}' not found in state dict for embodiment '{embodiment_tag}'")
+
+ # Sin/cos encoding is not reversible
+ if sin_cos_keys and joint_group in sin_cos_keys:
+ raise ValueError(
+ f"Cannot unapply sin/cos encoding for joint group '{joint_group}' "
+ f"in embodiment '{embodiment_tag}'. This transformation is not reversible."
+ )
+
+ # Reverse mean/std normalization
+ elif (
+ hasattr(
+ self.modality_configs[embodiment_tag]["state"],
+ "mean_std_embedding_keys",
+ )
+ and self.modality_configs[embodiment_tag]["state"].mean_std_embedding_keys
+ and joint_group in self.modality_configs[embodiment_tag]["state"].mean_std_embedding_keys
+ ):
+ params = self.norm_params[embodiment_tag]["state"][joint_group]
+ unnormalized = unnormalize_values_meanstd(state[joint_group], params)
+ unnormalized_values[joint_group] = unnormalized
+
+ # Reverse min/max normalization
+ else:
+ params = self.norm_params[embodiment_tag]["state"][joint_group]
+ unnormalized_values[joint_group] = unnormalize_values_minmax(state[joint_group], params)
+
+ return unnormalized_values
+
+ def apply_action(
+ self,
+ action: dict[str, np.ndarray],
+ embodiment_tag: str,
+ state: dict[str, np.ndarray] | None = None,
+ ) -> dict[str, np.ndarray]:
+ """
+ Apply action processing (absolute->relative conversion, normalization).
+
+ Processing order:
+ 1. Convert absolute actions to relative (if configured)
+ 2. Normalize actions
+
+ Args:
+ action: Dict mapping joint_group -> raw action values
+ Shape per group: (T, D) where T is action horizon, D is action dimension
+ embodiment_tag: Embodiment identifier
+ state: Optional dict mapping joint_group -> raw state values
+ Required if any action group uses ActionRepresentation.RELATIVE
+ Shape per group: (T_state, D) where last timestep is used as reference
+
+ Returns:
+ Dict mapping joint_group -> processed action values
+ Shape per group: (T, D)
+
+ Raises:
+ ValueError: If state is None but required for relative action conversion
+ """
+ action = deepcopy(action) # Avoid modifying input
+
+ # Step 1: Convert absolute actions to relative (if needed)
+ modality_keys = self.modality_configs[embodiment_tag]["action"].modality_keys
+ action_configs = self.modality_configs[embodiment_tag]["action"].action_configs
+
+ if action_configs is not None:
+ for key, action_config in zip(modality_keys, action_configs):
+ if action_config.rep == ActionRepresentation.RELATIVE and self.use_relative_action:
+ if state is None:
+ raise ValueError(
+ f"State dict required for relative action processing of key '{key}' "
+ f"in embodiment '{embodiment_tag}'"
+ )
+
+ # Determine which state key to use as reference
+ state_key = action_config.state_key if action_config.state_key else key
+
+ if state_key not in state:
+ raise KeyError(
+ f"Reference state key '{state_key}' not found in state dict "
+ f"for embodiment '{embodiment_tag}'"
+ )
+
+ # Use last state as reference frame
+ reference_state = state[state_key][-1]
+
+ # Convert absolute to relative
+ action[key] = self._convert_to_relative_action(
+ action=action[key],
+ reference_state=reference_state,
+ action_type=action_config.type,
+ action_format=action_config.format,
+ )
+
+ # Step 2: Normalize actions
+ normalized_values = {}
+ for joint_group in modality_keys:
+ if joint_group not in action:
+ raise KeyError(
+ f"Joint group '{joint_group}' not found in action dict for embodiment '{embodiment_tag}'"
+ )
+
+ params = self.norm_params[embodiment_tag]["action"][joint_group]
+ if (
+ self.modality_configs[embodiment_tag]["action"].mean_std_embedding_keys is not None
+ and joint_group in self.modality_configs[embodiment_tag]["action"].mean_std_embedding_keys
+ ):
+ normalized = normalize_values_meanstd(action[joint_group], params)
+ else:
+ normalized = normalize_values_minmax(action[joint_group], params)
+
+ if self.clip_outliers:
+ normalized = np.clip(normalized, -1.0, 1.0)
+
+ normalized_values[joint_group] = normalized
+
+ return normalized_values
+
+ def unapply_action(
+ self,
+ action: dict[str, np.ndarray],
+ embodiment_tag: str,
+ state: dict[str, np.ndarray] | None = None,
+ ) -> dict[str, np.ndarray]:
+ """
+ Reverse action processing (denormalization, relative->absolute conversion).
+
+ Processing order:
+ 1. Denormalize actions
+ 2. Convert relative actions to absolute (if configured)
+
+ Args:
+ action: Dict mapping joint_group -> processed action values
+ Shape per group: (T, D) or (B, T, D) for batched
+ embodiment_tag: Embodiment identifier
+ state: Optional dict mapping joint_group -> raw state values
+ Required if any action group uses ActionRepresentation.RELATIVE
+ Shape per group: (T_state, D) or (B, T_state, D) for batched
+
+ Returns:
+ Dict mapping joint_group -> raw absolute action values
+ Shape per group: (T, D) or (B, T, D) for batched
+
+ Raises:
+ ValueError: If state is None but required for relative->absolute conversion
+ """
+ # Step 1: Unnormalize actions
+ unnormalized_values = {}
+ modality_keys = self.modality_configs[embodiment_tag]["action"].modality_keys
+
+ for joint_group in modality_keys:
+ if joint_group not in action:
+ raise KeyError(
+ f"Joint group '{joint_group}' not found in action dict for embodiment '{embodiment_tag}'"
+ )
+
+ params = self.norm_params[embodiment_tag]["action"][joint_group]
+ group_values = action[joint_group]
+
+ if (
+ self.modality_configs[embodiment_tag]["action"].mean_std_embedding_keys is not None
+ and joint_group in self.modality_configs[embodiment_tag]["action"].mean_std_embedding_keys
+ ):
+ unnormalized = unnormalize_values_meanstd(group_values, params)
+ else:
+ unnormalized = unnormalize_values_minmax(group_values, params)
+
+ unnormalized_values[joint_group] = unnormalized
+
+ # Step 2: Convert relative actions to absolute (if needed)
+ action_configs = self.modality_configs[embodiment_tag]["action"].action_configs
+
+ if action_configs is not None:
+ for key, action_config in zip(modality_keys, action_configs):
+ if action_config.rep == ActionRepresentation.RELATIVE and self.use_relative_action:
+ if state is None:
+ raise ValueError(
+ f"State dict required for relative->absolute conversion of key '{key}' "
+ f"in embodiment '{embodiment_tag}'"
+ )
+
+ # Determine which state key to use as reference
+ state_key = action_config.state_key if action_config.state_key else key
+
+ if state_key not in state:
+ raise KeyError(
+ f"Reference state key '{state_key}' not found in state dict "
+ f"for embodiment '{embodiment_tag}'"
+ )
+
+ relative_action = unnormalized_values[key]
+
+ # Handle batched and unbatched cases
+ is_batched = relative_action.ndim == 3
+ if not is_batched:
+ assert relative_action.ndim == 2
+ reference_state = state[state_key]
+ if reference_state.ndim == 2:
+ reference_state = reference_state[None, :]
+ relative_action = relative_action[None, :]
+ else:
+ reference_state = state[state_key]
+ if reference_state.ndim == 2:
+ reference_state = reference_state[None, :]
+
+ # Convert batched relative actions to absolute
+ absolute_actions = []
+ for s, a in zip(reference_state, relative_action):
+ # Use last timestep of state as reference
+ absolute_action = self._convert_to_absolute_action(
+ action=a,
+ reference_state=s[-1],
+ action_type=action_config.type,
+ action_format=action_config.format,
+ )
+ absolute_actions.append(absolute_action)
+
+ if is_batched:
+ unnormalized_values[key] = np.stack(absolute_actions, axis=0)
+ else:
+ unnormalized_values[key] = absolute_actions[0]
+
+ return unnormalized_values
+
+ def apply(
+ self,
+ state: dict[str, np.ndarray],
+ action: dict[str, np.ndarray],
+ embodiment_tag: str,
+ ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
+ """
+ Apply both state and action processing together.
+
+ Convenience method that processes state and action in one call,
+ automatically passing raw state to action processor for relative conversion.
+
+ Args:
+ state: Dict mapping joint_group -> raw state values
+ action: Dict mapping joint_group -> raw action values
+ embodiment_tag: Embodiment identifier
+
+ Returns:
+ Tuple of (processed_state, processed_action)
+ """
+ processed_state = self.apply_state(state, embodiment_tag)
+ if action:
+ processed_action = self.apply_action(action, embodiment_tag, state=state)
+ else:
+ processed_action = {}
+ return processed_state, processed_action
+
+ def unapply(
+ self,
+ state: dict[str, np.ndarray],
+ action: dict[str, np.ndarray],
+ embodiment_tag: str,
+ raw_state: dict[str, np.ndarray] | None = None,
+ ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
+ """
+ Reverse both state and action processing together.
+
+ Args:
+ state: Dict mapping joint_group -> processed state values
+ action: Dict mapping joint_group -> processed action values
+ embodiment_tag: Embodiment identifier
+ raw_state: Optional dict of raw states for relative->absolute conversion
+ If None, will use unapplied state (but won't work for sin/cos encoded states)
+
+ Returns:
+ Tuple of (raw_state, raw_action)
+ """
+ # Unapply state first
+ try:
+ unapplied_state = self.unapply_state(state, embodiment_tag)
+ except ValueError as e:
+ if "sin/cos encoding" in str(e) and raw_state is None:
+ raise ValueError("Cannot unapply sin/cos encoded state. Please provide raw_state parameter.") from e
+ raise
+
+ # Use provided raw_state if available, otherwise use unapplied state
+ state_for_action = raw_state if raw_state is not None else unapplied_state
+
+ # Unapply action
+ unapplied_action = self.unapply_action(action, embodiment_tag, state=state_for_action)
+
+ return unapplied_state, unapplied_action
+
+ def get_state_dim(self, embodiment_tag: str, include_sincos_expansion: bool = False) -> int:
+ """
+ Get total state dimension after processing.
+
+ Args:
+ embodiment_tag: Embodiment identifier
+ include_sincos_expansion: If True, accounts for sin/cos encoding doubling dimensions
+
+ Returns:
+ Total state dimension across all joint groups
+ """
+ total_dim = 0
+ state_config = self.modality_configs[embodiment_tag]["state"]
+
+ # Get sin/cos embedding keys if enabled
+ sin_cos_keys = set()
+ if self.apply_sincos_state_encoding and hasattr(state_config, "sin_cos_embedding_keys"):
+ sin_cos_keys = set(state_config.sin_cos_embedding_keys)
+
+ for joint_group in state_config.modality_keys:
+ base_dim = self.norm_params[embodiment_tag]["state"][joint_group]["dim"].item()
+
+ # Sin/cos encoding doubles the dimension
+ if include_sincos_expansion and joint_group in sin_cos_keys:
+ total_dim += base_dim * 2
+ else:
+ total_dim += base_dim
+
+ return total_dim
+
+ def get_action_dim(self, embodiment_tag: str) -> int:
+ """
+ Get total action dimension.
+
+ Args:
+ embodiment_tag: Embodiment identifier
+
+ Returns:
+ Total action dimension across all joint groups
+ """
+ total_dim = 0
+ for joint_group in self.modality_configs[embodiment_tag]["action"].modality_keys:
+ total_dim += self.norm_params[embodiment_tag]["action"][joint_group]["dim"].item()
+ return total_dim
+
+ def _convert_to_relative_action(
+ self,
+ action: np.ndarray,
+ reference_state: np.ndarray,
+ action_type: ActionType,
+ action_format: ActionFormat,
+ ) -> np.ndarray:
+ """Convert absolute action to relative action using reference state."""
+ assert action.ndim == 2, f"Expected action shape (T, D), got {action.shape}"
+ assert reference_state.ndim == 1, f"Expected state shape (D,), got {reference_state.shape}"
+
+ if action_type == ActionType.EEF:
+ action_chunking = EndEffectorActionChunk.from_array(action, action_format)
+ reference_frame = EndEffectorPose.from_action_format(reference_state, action_format)
+
+ elif action_type == ActionType.NON_EEF:
+ action_chunking = JointActionChunk([JointPose(m) for m in action])
+ reference_frame = JointPose(reference_state)
+
+ else:
+ raise ValueError(f"Unknown ActionType: {action_type}")
+
+ relative_action_chunking = action_chunking.relative_chunking(reference_frame=reference_frame)
+ return relative_action_chunking.to(action_format)
+
+ def _convert_to_absolute_action(
+ self,
+ action: np.ndarray,
+ reference_state: np.ndarray,
+ action_type: ActionType,
+ action_format: ActionFormat,
+ ) -> np.ndarray:
+ """Convert relative action to absolute action using reference state."""
+ assert action.ndim == 2, f"Expected action shape (T, D), got {action.shape}"
+ assert reference_state.ndim == 1, f"Expected state shape (D,), got {reference_state.shape}"
+ assert reference_state.shape[0] == action.shape[1], (
+ f"State dim {reference_state.shape[0]} != action dim {action.shape[1]}"
+ )
+
+ if action_type == ActionType.EEF:
+ rel_action = EndEffectorActionChunk.from_array(action, action_format)
+ reference_frame = EndEffectorPose.from_action_format(reference_state, action_format)
+
+ elif action_type == ActionType.NON_EEF:
+ rel_action = JointActionChunk([JointPose(pose) for pose in action])
+ reference_frame = JointPose(reference_state)
+
+ else:
+ raise ValueError(f"Unknown ActionType: {action_type}")
+
+ abs_action = rel_action.to_absolute_chunking(reference_frame=reference_frame)
+ return abs_action.to(action_format)
+
+ def __str__(self) -> str:
+ return (
+ "StateActionProcessor("
+ f"modality_configs={self.modality_configs}, "
+ f"statistics={self.statistics}, "
+ f"use_percentiles={self.use_percentiles}, "
+ f"clip_outliers={self.clip_outliers}, "
+ f"apply_sincos_state_encoding={self.apply_sincos_state_encoding}, "
+ f"use_relative_action={self.use_relative_action})"
+ )
diff --git a/vllm_omni/diffusion/models/gr00t/dataio/types.py b/vllm_omni/diffusion/models/gr00t/dataio/types.py
new file mode 100644
index 00000000000..d4948a1c234
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/dataio/types.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+import numpy as np
+
+from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag
+
+
class MessageType(Enum):
    """Closed set of message kinds; each member's value is its lowercase string tag."""

    # Episode boundaries and per-step messages
    START_OF_EPISODE = "start_of_episode"
    END_OF_EPISODE = "end_of_episode"
    EPISODE_STEP = "episode_step"

    # Payload-typed messages
    IMAGE = "image"
    TEXT = "text"
+
+
class ActionRepresentation(Enum):
    """How an action group's values are represented (relative, delta or absolute)."""

    RELATIVE = "relative"
    DELTA = "delta"
    ABSOLUTE = "absolute"
+
+
class ActionType(Enum):
    """Whether an action group is an end-effector (EEF) action or not (NON_EEF)."""

    EEF = "eef"
    NON_EEF = "non_eef"
+
+
class ActionFormat(Enum):
    """Array layout tag for an action group; values are the layout's string name."""

    DEFAULT = "default"
    XYZ_ROT6D = "xyz+rot6d"
    XYZ_ROTVEC = "xyz+rotvec"
+
+
@dataclass
class VLAStepData:
    """
    One step of VLA (Vision-Language-Action) data.

    Core structure returned by datasets: it carries the raw observation and action
    data that is later processed by the SequenceVLAProcessor.
    """

    # --- Core data ---
    # view_name -> list[np.ndarray] (for temporal stacking)
    images: dict[str, list[np.ndarray]]
    # state_name -> np.ndarray (dim,) for single step or (horizon, dim) for trajectory
    states: dict[str, np.ndarray]
    # action_name -> np.ndarray (horizon, dim) for action chunk
    actions: dict[str, np.ndarray]
    # view_name -> list[np.ndarray] (H, W)
    masks: dict[str, list[np.ndarray]] | None = None
    # Optional task description or instruction
    text: str | None = None
    # Optional embodiment tag for cross-embodiment training
    embodiment: EmbodimentTag = EmbodimentTag.NEW_EMBODIMENT
    # Whether the step is a demonstration. If True, no loss should be computed for this step.
    is_demonstration: bool = False

    # --- Flexible metadata that can be extended by users ---
    metadata: dict[str, Any] = field(default_factory=dict)
+
+
@dataclass
class ActionConfig:
    """Per-action-group settings: representation, type, array format and reference state."""

    # Representation of the action values (relative / delta / absolute)
    rep: ActionRepresentation
    # End-effector vs. non-end-effector action
    type: ActionType
    # Array layout of the action values
    format: ActionFormat
    # Optional name of the state group used as reference; None means no override
    state_key: str | None = None
+
+
@dataclass
class ModalityConfig:
    """Configuration for a modality defining how data should be sampled and loaded.

    This class specifies which indices to sample relative to a base index and which
    keys to load for a particular modality (e.g., video, state, action).

    Raises:
        ValueError: On construction, if ``delta_indices`` is not a list, if
            ``modality_keys`` is not a non-empty list, or if ``action_configs`` is
            given with a length different from ``modality_keys``.
        KeyError: On construction, if a dict entry of ``action_configs`` lacks a
            required key or names an unknown enum member.
    """

    delta_indices: list[int]
    """Delta indices to sample relative to the current index. The returned data will
    correspond to the original data at a sampled base index + delta indices."""
    modality_keys: list[str]
    """The keys to load for the modality in the dataset."""
    sin_cos_embedding_keys: list[str] | None = None
    """Optional list of keys to apply sin/cos encoding. If None or empty, use
    min/max normalization for all keys."""
    mean_std_embedding_keys: list[str] | None = None
    """Optional list of keys to apply mean/std normalization. If None or empty,
    use min/max normalization for all keys."""
    # Optional per-key action settings, aligned one-to-one with modality_keys.
    # Entries may be ActionConfig instances or dicts with "rep", "type", "format"
    # (enum member NAMES) and an optional "state_key".
    action_configs: list[ActionConfig] | None = None

    def __post_init__(self):
        """Validate fields and convert dict-form action configs to ActionConfig."""
        # isinstance() already rejects None, so no separate None check is needed
        if not isinstance(self.delta_indices, list):
            raise ValueError(f"delta_indices must be a non-None list, got {self.delta_indices!r}")
        if not isinstance(self.modality_keys, list) or not self.modality_keys:
            raise ValueError(f"modality_keys must be a non-empty list, got {self.modality_keys!r}")
        if self.action_configs is None:
            return

        # Raise explicitly (instead of assert) so the check survives `python -O`
        if len(self.action_configs) != len(self.modality_keys):
            raise ValueError(
                f"Number of action configs ({len(self.action_configs)}) must match "
                f"number of modality keys ({len(self.modality_keys)})"
            )

        parsed_action_configs = []
        for action_config in self.action_configs:
            if isinstance(action_config, dict):
                # Enum fields are looked up by member name, e.g. "RELATIVE"
                action_config = ActionConfig(
                    rep=ActionRepresentation[action_config["rep"]],
                    type=ActionType[action_config["type"]],
                    format=ActionFormat[action_config["format"]],
                    state_key=action_config.get("state_key", None),
                )
            parsed_action_configs.append(action_config)
        self.action_configs = parsed_action_configs
diff --git a/vllm_omni/diffusion/models/gr00t/dataio/utils.py b/vllm_omni/diffusion/models/gr00t/dataio/utils.py
new file mode 100644
index 00000000000..8096d62396c
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/dataio/utils.py
@@ -0,0 +1,291 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import asdict, is_dataclass
+from enum import Enum
+from typing import Any
+
+import numpy as np
+
+from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import ModalityConfig
+
+
def apply_sin_cos_encoding(values: np.ndarray) -> np.ndarray:
    """Encode values as their sines followed by their cosines along the last axis.

    Args:
        values: Array of shape (..., D) containing values to encode

    Returns:
        Array of shape (..., 2*D): the first D entries are sin(values), the last D
        entries are cos(values)

    Note: the last dimension is DOUBLED. For example:
        Input:  [v1, v2, v3] with shape (..., 3)
        Output: [sin(v1), sin(v2), sin(v3), cos(v1), cos(v2), cos(v3)] with shape (..., 6)
    """
    return np.concatenate((np.sin(values), np.cos(values)), axis=-1)
+
+
def nested_dict_to_numpy(data):
    """
    Walk a nested dictionary and turn every list found as a value into a NumPy array.

    Args:
        data: A nested dictionary whose leaves are lists (possibly lists of lists).
            Anything that is neither a dict nor a list is passed through unchanged.

    Returns:
        A new dictionary with the same keys and nesting, where each list leaf has
        been replaced by ``np.array(leaf)``

    Example:
        >>> data = {"a": {"b": [[0, 1], [2, 3]]}}
        >>> result = nested_dict_to_numpy(data)
        >>> print(result["a"]["b"])
        [[0 1]
         [2 3]]
    """
    if isinstance(data, list):
        # np.array copes with both flat and nested lists
        return np.array(data)
    if not isinstance(data, dict):
        return data

    converted = {}
    for key, value in data.items():
        converted[key] = nested_dict_to_numpy(value)
    return converted
+
+
def normalize_values_minmax(values, params):
    """
    Normalize values using min-max normalization to [-1, 1] range.

    Args:
        values: Input values to normalize
            - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension
            - Can handle 2D or 3D arrays where last axis represents features
            - Integer input is accepted; the result is then float64
        params: Dictionary with "min" and "max" keys
            - params["min"]: Minimum values for normalization
                * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps
                * Case 2 - 2D bounds: Shape (T, D) - different min/max per step
            - params["max"]: Maximum values for normalization
                * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps
                * Case 2 - 2D bounds: Shape (T, D) - different min/max per step

    Returns:
        Normalized values
            - Same shape as input values: (T, D) or (B, T, D)
            - Same dtype as the input when it is floating point, float64 otherwise
            - Values are linearly mapped so that min -> -1 and max -> 1; inputs outside
              [min, max] map outside [-1, 1] (no clipping is done here)
            - For features where min and max are (numerically) equal, normalized value is 0

    Examples:
        # 1D bounds - same normalization for all steps
        values: (10, 5), params["min"]: (5,), params["max"]: (5,)

        # 2D bounds - per-step normalization
        values: (8, 4), params["min"]: (8, 4), params["max"]: (8, 4)
    """
    values = np.asarray(values)
    min_vals = np.asarray(params["min"])
    max_vals = np.asarray(params["max"])

    # np.zeros_like(values) would inherit an integer dtype and truncate the
    # fractional results, so non-float input gets a float64 output buffer.
    out_dtype = values.dtype if np.issubdtype(values.dtype, np.floating) else np.float64
    normalized = np.zeros_like(values, dtype=out_dtype)

    # Features with a degenerate range are left at 0 to avoid dividing by zero
    mask = ~np.isclose(max_vals, min_vals)

    normalized[..., mask] = (values[..., mask] - min_vals[..., mask]) / (max_vals[..., mask] - min_vals[..., mask])
    normalized[..., mask] = 2 * normalized[..., mask] - 1

    return normalized
+
+
def unnormalize_values_minmax(normalized_values, params):
    """
    Map min-max normalized values from [-1, 1] back to the original [min, max] range.

    Args:
        normalized_values: Normalized input values
            - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension
            - Values outside [-1, 1] are clipped to that interval first
        params: Dictionary with "min" and "max" keys
            - params["min"]: Original minimum values used for normalization
                * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps
                * Case 2 - 2D bounds: Shape (T, D) - different min/max per step
            - params["max"]: Original maximum values used for normalization
                * Case 1 - 1D bounds: Shape (D,) - same min/max for all steps
                * Case 2 - 2D bounds: Shape (T, D) - different min/max per step

    Returns:
        Unnormalized values in original range [min, max]
            - Same shape as input normalized_values: (T, D) or (B, T, D)
            - -1 maps to min, +1 maps to max, linearly in between

    Examples:
        # 1D bounds - same unnormalization for all steps
        normalized_values: (10, 5), params["min"]: (5,), params["max"]: (5,)

        # 2D bounds - per-step unnormalization
        normalized_values: (8, 4), params["min"]: (8, 4), params["max"]: (8, 4)
    """
    lower = params["min"]
    span = params["max"] - lower

    # Clip, shift [-1, 1] -> [0, 1], then scale into [min, max]
    unit = (np.clip(normalized_values, -1.0, 1.0) + 1.0) / 2.0
    return unit * span + lower
+
+
def normalize_values_meanstd(values, params):
    """
    Normalize values using mean-std (z-score) normalization.

    Args:
        values: Input values to normalize
            - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension
            - Can handle 2D or 3D arrays where last axis represents features
            - Integer input is accepted; the result is then float64
        params: Dictionary with "mean" and "std" keys
            - params["mean"]: Mean values for normalization
                * Case 1 - 1D params: Shape (D,) - same mean for all steps
                * Case 2 - 2D params: Shape (T, D) - different mean per step
            - params["std"]: Standard deviation values for normalization
                * Case 1 - 1D params: Shape (D,) - same std for all steps
                * Case 2 - 2D params: Shape (T, D) - different std per step

    Returns:
        Normalized values using z-score normalization
            - Same shape as input values: (T, D) or (B, T, D)
            - Same dtype as the input when it is floating point, float64 otherwise
            - Values are transformed as: (x - mean) / std
            - For features where std == 0, normalized value equals original value

    Examples:
        # 1D params - same normalization for all steps
        values: (10, 5), params["mean"]: (5,), params["std"]: (5,)

        # 2D params - per-step normalization
        values: (8, 4), params["mean"]: (8, 4), params["std"]: (8, 4)
    """
    values = np.asarray(values)
    mean_vals = np.asarray(params["mean"])
    std_vals = np.asarray(params["std"])

    # Create mask for non-zero standard deviations
    mask = std_vals != 0

    # np.zeros_like(values) would inherit an integer dtype and truncate the
    # z-scores, so non-float input gets a float64 output buffer.
    out_dtype = values.dtype if np.issubdtype(values.dtype, np.floating) else np.float64
    normalized = np.zeros_like(values, dtype=out_dtype)

    # Normalize only features with non-zero std
    normalized[..., mask] = (values[..., mask] - mean_vals[..., mask]) / std_vals[..., mask]

    # Keep original values for zero-std features
    normalized[..., ~mask] = values[..., ~mask]

    return normalized
+
+
def unnormalize_values_meanstd(normalized_values, params):
    """
    Mean-std unnormalization (reverse z-score normalization).

    Args:
        normalized_values: Normalized input values (z-scores)
            - Shape: (T, D) or (B, T, D) where B is batch, T is time/step, D is feature dimension
            - Can handle 2D or 3D arrays where last axis represents features
            - Integer input is accepted; the result is then float64
        params: Dictionary with "mean" and "std" keys
            - params["mean"]: Original mean values used for normalization
                * Case 1 - 1D params: Shape (D,) - same mean for all steps
                * Case 2 - 2D params: Shape (T, D) - different mean per step
            - params["std"]: Original standard deviation values used for normalization
                * Case 1 - 1D params: Shape (D,) - same std for all steps
                * Case 2 - 2D params: Shape (T, D) - different std per step

    Returns:
        Unnormalized values in original scale
            - Same shape as input normalized_values: (T, D) or (B, T, D)
            - Same dtype as the input when it is floating point, float64 otherwise
            - Values are transformed as: x * std + mean
            - For features where std == 0, unnormalized value equals normalized value

    Examples:
        # 1D params - same unnormalization for all steps
        normalized_values: (10, 5), params["mean"]: (5,), params["std"]: (5,)

        # 2D params - per-step unnormalization
        normalized_values: (8, 4), params["mean"]: (8, 4), params["std"]: (8, 4)
    """
    normalized_values = np.asarray(normalized_values)
    mean_vals = np.asarray(params["mean"])
    std_vals = np.asarray(params["std"])

    # Create mask for non-zero standard deviations
    mask = std_vals != 0

    # np.zeros_like(...) would inherit an integer dtype and truncate the results,
    # so non-float input gets a float64 output buffer.
    if np.issubdtype(normalized_values.dtype, np.floating):
        out_dtype = normalized_values.dtype
    else:
        out_dtype = np.float64
    unnormalized = np.zeros_like(normalized_values, dtype=out_dtype)

    # Unnormalize only features with non-zero std
    unnormalized[..., mask] = normalized_values[..., mask] * std_vals[..., mask] + mean_vals[..., mask]

    # Keep normalized values for zero-std features
    unnormalized[..., ~mask] = normalized_values[..., ~mask]

    return unnormalized
+
+
def to_json_serializable(obj: Any) -> Any:
    """
    Recursively reduce an object to plain JSON-compatible Python values.

    Conversion rules, checked in this order:
        - dataclass instance -> its ``asdict`` form, converted recursively
        - numpy array -> nested list
        - numpy integer / floating / bool scalar -> Python int / float / bool
        - dict -> dict with the same keys and converted values
        - list, tuple or set -> list of converted items
        - str, int, float, bool, None -> returned unchanged
        - Enum member -> its name
        - anything else -> ``str(obj)``

    Args:
        obj: Object to convert (can be dataclass, numpy array, dict, list, etc.)

    Returns:
        JSON-serializable representation of the object
    """
    # Dataclass instances only; dataclass *types* fall through to the str() fallback
    if is_dataclass(obj) and not isinstance(obj, type):
        return to_json_serializable(asdict(obj))

    if isinstance(obj, np.ndarray):
        return obj.tolist()

    # NumPy scalars are unwrapped before the builtin check below
    for numpy_kind, python_kind in ((np.integer, int), (np.floating, float), (np.bool_, bool)):
        if isinstance(obj, numpy_kind):
            return python_kind(obj)

    if isinstance(obj, dict):
        return {key: to_json_serializable(value) for key, value in obj.items()}

    if isinstance(obj, (list, tuple, set)):
        return [to_json_serializable(item) for item in obj]

    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj

    if isinstance(obj, Enum):
        return obj.name

    # Last resort for unknown types
    return str(obj)
+
+
def parse_modality_configs(
    modality_configs: dict[str, dict[str, ModalityConfig]],
) -> dict[str, dict[str, ModalityConfig]]:
    """
    Build ModalityConfig objects from dict-form entries of a nested config mapping.

    Args:
        modality_configs: Mapping embodiment_tag -> modality -> config, where each
            config is either a ModalityConfig or a dict of its constructor arguments

    Returns:
        A new mapping with the same keys in which every dict entry has been replaced
        by ``ModalityConfig(**entry)``; non-dict entries are carried over as they are
    """
    return {
        embodiment_tag: {
            modality: ModalityConfig(**config) if isinstance(config, dict) else config
            for modality, config in per_modality.items()
        }
        for embodiment_tag, per_modality in modality_configs.items()
    }
diff --git a/vllm_omni/diffusion/models/gr00t/modeling/__init__.py b/vllm_omni/diffusion/models/gr00t/modeling/__init__.py
new file mode 100644
index 00000000000..7aecb405ff9
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/modeling/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7
+from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor
+
+__all__ = ["Gr00tN1d7", "Gr00tN1d7Processor"]
diff --git a/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py
new file mode 100644
index 00000000000..4799b7cf615
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/modeling/gr00t_n1d7.py
@@ -0,0 +1,479 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import logging
+import os
+from typing import Any
+
+import torch
+from torch import nn
+from transformers import AutoConfig, AutoModel, PreTrainedModel
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+ Qwen3VLForConditionalGeneration as _Qwen3VLForConditionalGeneration,
+)
+
+from vllm_omni.diffusion.models.gr00t.configs.gr00t_n1d7 import Gr00tN1d7Config
+from vllm_omni.diffusion.models.gr00t.modeling.modules.dit import AlternateVLDiT, DiT, SelfAttentionTransformer
+from vllm_omni.diffusion.models.gr00t.modeling.modules.embodiment_conditioned_mlp import (
+ CategorySpecificMLP,
+ MultiEmbodimentActionEncoder,
+)
+from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7DataCollator
+
+logger = logging.getLogger(__name__)
+
+
+class Gr00tN1d7ActionHead(nn.Module):
+    """Action head component for flow matching diffusion policy.
+
+    Owns the diffusion transformer (``DiT`` or ``AlternateVLDiT``), the
+    per-embodiment state/action encoders and the action decoder, and runs the
+    Euler denoising loop in :meth:`get_action_with_features`.
+
+    NOTE(review): this class defines no ``forward``; only the ``get_action*``
+    inference entry points are usable as written.
+    """
+
+    supports_gradient_checkpointing = True
+
+    def __init__(self, config: Gr00tN1d7Config):
+        """Build the diffusion transformer and the encoder/decoder MLPs from ``config``."""
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.input_embedding_dim = config.input_embedding_dim
+
+        # Both variants receive the backbone embedding width as their
+        # cross-attention dimension; the alternate variant additionally takes
+        # the text-attention period.
+        if config.use_alternate_vl_dit:
+            self.model = AlternateVLDiT(
+                **config.diffusion_model_cfg,
+                cross_attention_dim=config.backbone_embedding_dim,
+                attend_text_every_n_blocks=config.attend_text_every_n_blocks,
+            )
+            logger.info("Using AlternateVLDiT for diffusion model")
+        else:
+            self.model = DiT(
+                **config.diffusion_model_cfg,
+                cross_attention_dim=config.backbone_embedding_dim,
+            )
+            logger.info("Using DiT for diffusion model")
+        self.action_dim = config.max_action_dim
+        self.action_horizon = config.action_horizon
+        self.num_inference_timesteps = config.num_inference_timesteps
+
+        # The state encoder consumes the whole state history flattened into one
+        # vector (see the ``view`` in ``_encode_features``).
+        self.state_encoder = CategorySpecificMLP(
+            num_categories=config.max_num_embodiments,
+            input_dim=config.max_state_dim * config.state_history_length,
+            hidden_dim=self.hidden_size,
+            output_dim=self.input_embedding_dim,
+        )
+        self.action_encoder = MultiEmbodimentActionEncoder(
+            action_dim=self.action_dim,
+            hidden_size=self.input_embedding_dim,
+            num_embodiments=config.max_num_embodiments,
+        )
+        self.action_decoder = CategorySpecificMLP(
+            num_categories=config.max_num_embodiments,
+            input_dim=self.hidden_size,
+            hidden_dim=self.hidden_size,
+            output_dim=self.action_dim,
+        )
+
+        # Optional LayerNorm over the backbone features; identity when disabled.
+        self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
+
+        # Optional self-attention stack over the backbone features; identity
+        # when the config is empty or requests zero layers.
+        vl_self_attention_cfg = config.vl_self_attention_cfg
+        if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
+            self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
+        else:
+            self.vl_self_attention = nn.Identity()
+
+        # Learned position embedding for the action tokens; the attribute only
+        # exists when ``config.add_pos_embed`` is set.
+        if config.add_pos_embed:
+            self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
+            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
+
+        self.num_timestep_buckets = config.num_timestep_buckets
+
+    def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
+        """Apply ``vlln`` then ``vl_self_attention`` to ``backbone_features``.
+
+        The result is written back into ``backbone_output`` (mutated in place)
+        and the same object is returned.
+        """
+        backbone_features = backbone_output["backbone_features"]
+        backbone_features = self.vlln(backbone_features)
+        backbone_features = self.vl_self_attention(backbone_features)
+        backbone_output["backbone_features"] = backbone_features
+        return backbone_output
+
+    def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
+        """
+        Encode features for the action head.
+
+        Args:
+            backbone_output: Output from the backbone model containing:
+                - backbone_features: [B, seq_len, backbone_embedding_dim]
+                - backbone_attention_mask: [B, seq_len]
+            action_input: Input containing:
+                - state: [B, state_history_length, max_state_dim]
+                - embodiment_id: [B] (embodiment IDs)
+
+        Returns:
+            BatchFeature containing:
+                - backbone_features: [B, seq_len, backbone_embedding_dim]
+                - state_features: [B, 1, input_embedding_dim]
+        """
+        # NOTE: this mutates ``backbone_output["backbone_features"]`` in place.
+        backbone_output = self.process_backbone_output(backbone_output)
+
+        vl_embeds = backbone_output.backbone_features
+        embodiment_id = action_input.embodiment_id
+
+        state = action_input.state
+        current_T = state.shape[1]
+        assert current_T == self.config.state_history_length, "current_T != state_history_length"
+        # [B, state_history_length, max_state_dim] -> [B, 1, state_history_length * max_state_dim]
+        state = state.view(state.shape[0], 1, -1)
+
+        state_features = self.state_encoder(state, embodiment_id)
+
+        return BatchFeature(data={"backbone_features": vl_embeds, "state_features": state_features})
+
+    @torch.no_grad()
+    def get_action_with_features(
+        self,
+        backbone_features: torch.Tensor,
+        state_features: torch.Tensor,
+        embodiment_id: torch.Tensor,
+        backbone_output: BatchFeature,
+        action_input: BatchFeature,
+        options: dict[str, Any] | None = None,
+    ) -> BatchFeature:
+        """
+        Generate actions using the flow matching diffusion process.
+
+        Args:
+            backbone_features: [B, seq_len, backbone_embedding_dim]
+            state_features: [B, state_horizon, input_embedding_dim]
+            embodiment_id: [B] (embodiment IDs)
+            backbone_output: Output from the backbone model; ``image_mask`` and
+                ``backbone_attention_mask`` are read from it when
+                ``config.use_alternate_vl_dit`` is set.
+            action_input: Action-head inputs. If it contains an ``"action"``
+                entry, real-time chunking (RTC) is enabled and ``options`` is
+                required.
+            options: RTC settings; must provide ``action_horizon``,
+                ``rtc_overlap_steps``, ``rtc_frozen_steps`` and
+                ``rtc_ramp_rate`` when RTC is enabled.
+
+        Returns:
+            BatchFeature with ``action_pred`` (the denoised actions),
+            ``backbone_features`` and ``state_features``.
+        """
+        vl_embeds = backbone_features
+
+        batch_size = vl_embeds.shape[0]
+        device = vl_embeds.device
+        # If GR00T_NOISE_SEED is set, draw the initial noise from a generator
+        # seeded with its integer value; otherwise use the global RNG state.
+        _seed_env = os.environ.get("GR00T_NOISE_SEED")
+        if _seed_env is not None:
+            _gen = torch.Generator(device=device).manual_seed(int(_seed_env))
+            actions = torch.randn(
+                size=(batch_size, self.config.action_horizon, self.action_dim),
+                dtype=vl_embeds.dtype,
+                device=device,
+                generator=_gen,
+            )
+        else:
+            actions = torch.randn(
+                size=(batch_size, self.config.action_horizon, self.action_dim),
+                dtype=vl_embeds.dtype,
+                device=device,
+            )
+
+        # Euler step size, and a per-element multiplier on the predicted
+        # velocity (1 = full update, 0 = keep the value frozen).
+        dt = 1.0 / self.num_inference_timesteps
+        vel_strength = torch.ones_like(actions)
+
+        if "action" in action_input:
+            # If action in input when doing get action, it means we want to use RTC.
+            # action_horizon is the action horizon of the input action.
+            # rtc_overlap_steps is the number of steps to overlap with the previous action chunks.
+            # rtc_frozen_steps is the policy inference latency expressed as frozen action steps.
+            # rtc_ramp_rate is the rate of the ramp of denoising the actions.
+            assert options is not None, "options is not None"
+            assert "action_horizon" in options, "action_horizon is not in options"
+            assert "rtc_overlap_steps" in options, "rtc_overlap_steps is not in options"
+            assert "rtc_frozen_steps" in options, "rtc_frozen_steps is not in options"
+            assert "rtc_ramp_rate" in options, "rtc_ramp_rate is not in options"
+
+            action_horizon_before_padding = options["action_horizon"]
+
+            # Use previous action instead of pure noise to do inpainting
+            # (the last rtc_overlap_steps of the previous, unpadded chunk
+            # become the first rtc_overlap_steps of the new one).
+            actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
+                :,
+                action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
+                :,
+            ]
+            vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
+            # NOTE: use an exponential ramp strength to set the remaining unfrozen rtc_steps
+            intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
+            # Create exponential ramp from 0 to 1 over intermediate steps
+            # (two extra samples are generated for the endpoints dropped below).
+            t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
+            ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
+            ramp = ramp / ramp[-1].clamp_min(1e-8)  # normalize to [0,1]
+            ramp = ramp[1:-1]  # we will only take the middle part of the ramp, ignore the 0.0 and 1.0
+            # Apply ramp to the intermediate steps [batch, intermediate_steps, action_dim]
+            vel_strength[
+                :,
+                options["rtc_frozen_steps"] : options["rtc_overlap_steps"],
+                :,
+            ] = ramp[None, :, None].to(device)
+
+        # Run denoising steps.
+        # NOTE: ``t`` is rebound here to the integer step index; the ramp
+        # tensor of the same name above is no longer needed at this point.
+        for t in range(self.num_inference_timesteps):
+            # Map the step index onto the [0, num_timestep_buckets) range.
+            t_discretized = t * self.num_timestep_buckets // self.num_inference_timesteps
+
+            # Embed noised action trajectory.
+            timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
+            action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
+            # Add position embedding.
+            if self.config.add_pos_embed:
+                pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
+                pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
+                action_features = action_features + pos_embs
+
+            # Join state and action embeddings along the sequence dimension.
+            # The vision-language embeddings are not concatenated here; they
+            # are passed to the model separately as encoder_hidden_states.
+            sa_embs = torch.cat((state_features, action_features), dim=1)
+
+            # Run model forward.
+            if self.config.use_alternate_vl_dit:
+                model_output = self.model(
+                    hidden_states=sa_embs,
+                    encoder_hidden_states=vl_embeds,
+                    timestep=timesteps_tensor,
+                    image_mask=backbone_output.image_mask,
+                    backbone_attention_mask=backbone_output.backbone_attention_mask,
+                )
+            else:
+                model_output = self.model(
+                    hidden_states=sa_embs,
+                    encoder_hidden_states=vl_embeds,
+                    timestep=timesteps_tensor,
+                )
+            pred = self.action_decoder(model_output, embodiment_id)
+
+            # Keep only the trailing action positions; the leading positions of
+            # the sequence belong to the state features concatenated above.
+            pred_velocity = pred[:, -self.action_horizon :]
+
+            # Update actions using euler integration.
+            actions = actions + dt * pred_velocity * vel_strength
+
+        return BatchFeature(
+            data={
+                "action_pred": actions,
+                "backbone_features": vl_embeds,
+                "state_features": state_features,
+            }
+        )
+
+    @torch.no_grad()
+    def get_action(
+        self,
+        backbone_output: BatchFeature,
+        action_input: BatchFeature,
+        options: dict[str, Any] | None = None,
+    ) -> BatchFeature:
+        """
+        Generate actions using the flow matching diffusion process.
+
+        Encodes the features with :meth:`_encode_features` and then delegates to
+        :meth:`get_action_with_features`.
+
+        Args:
+            backbone_output: Output from the backbone model containing:
+                - backbone_features: [B, seq_len, backbone_embedding_dim]
+                - backbone_attention_mask: [B, seq_len]
+            action_input: Input containing:
+                - state: [B, state_history_length, max_state_dim]
+                - embodiment_id: [B] (embodiment IDs)
+            options: Optional RTC settings, forwarded unchanged.
+
+        Returns:
+            BatchFeature containing:
+                - action_pred: [B, action_horizon, action_dim] predicted actions
+        """
+        features = self._encode_features(backbone_output, action_input)
+        return self.get_action_with_features(
+            backbone_features=features.backbone_features,
+            state_features=features.state_features,
+            embodiment_id=action_input.embodiment_id,
+            backbone_output=backbone_output,
+            action_input=action_input,
+            options=options,
+        )
+
+    @property
+    def device(self):
+        """Device of the first parameter of this module."""
+        return next(iter(self.parameters())).device
+
+    @property
+    def dtype(self):
+        """Dtype of the first parameter of this module."""
+        return next(iter(self.parameters())).dtype
+
+    def prepare_input(self, batch: dict) -> BatchFeature:
+        """Prepare input batch for the action head."""
+        return BatchFeature(data=batch)
+
+
+class _Qwen3VLBackbone(nn.Module):
+    """GR00T adapter around the shared Qwen3-VL implementation.
+
+    Wraps ``Qwen3VLForConditionalGeneration``, optionally truncates the decoder
+    stack, and exposes the hidden states that enter the final language-model
+    norm as ``backbone_features``.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        select_layer: int,
+        backbone_embedding_dim: int,
+        load_bf16: bool,
+        transformers_loading_kwargs: dict[str, Any] | None = None,
+    ):
+        """Instantiate the backbone from the config of ``model_name``.
+
+        Args:
+            model_name: Identifier handed to ``AutoConfig.from_pretrained``.
+            select_layer: Number of leading decoder layers to keep; a negative
+                value keeps all of them.
+            backbone_embedding_dim: Accepted but not read in this constructor.
+            load_bf16: Cast the model to ``torch.bfloat16`` when true.
+            transformers_loading_kwargs: Extra kwargs for ``from_pretrained``;
+                ``None`` or an empty dict falls back to
+                ``{"trust_remote_code": True}``.
+        """
+        super().__init__()
+        backbone_config = AutoConfig.from_pretrained(
+            model_name, **(transformers_loading_kwargs or {"trust_remote_code": True})
+        )
+        backbone_config.text_config._attn_implementation = "sdpa"
+        backbone_config.vision_config._attn_implementation = "sdpa"
+        # The model is constructed from the config only; no checkpoint is read
+        # here (presumably weights are loaded by the caller — TODO confirm).
+        self.model = _Qwen3VLForConditionalGeneration(backbone_config).eval()
+        if load_bf16:
+            self.model.to(dtype=torch.bfloat16)
+
+        # Drop trailing decoder layers until only ``target_layers`` remain.
+        target_layers = select_layer if select_layer >= 0 else len(self.model.model.language_model.layers)
+        while len(self.model.model.language_model.layers) > target_layers:
+            self.model.model.language_model.layers.pop(-1)
+
+    def set_frozen_modules_to_eval_mode(self) -> None:
+        """Put the whole backbone into eval mode."""
+        self.eval()
+
+    def prepare_input(self, batch: dict) -> BatchFeature:
+        """Wrap ``batch`` in a ``BatchFeature``."""
+        return BatchFeature(data=batch)
+
+    def forward(self, vl_input: BatchFeature) -> BatchFeature:
+        """Run the VLM and return pre-norm features plus attention/image masks.
+
+        Args:
+            vl_input: Batch from which ``input_ids``, ``attention_mask``,
+                ``pixel_values``, ``image_grid_thw`` and ``mm_token_type_ids``
+                are forwarded to the model when present. ``input_ids`` and
+                ``attention_mask`` are required to build the returned masks.
+
+        Returns:
+            BatchFeature with ``backbone_features`` (input of the final norm),
+            ``backbone_attention_mask`` (``attention_mask == 1``) and
+            ``image_mask`` (``input_ids == image_token_id``).
+
+        Raises:
+            RuntimeError: If the hook on the final norm never fired.
+        """
+        # Eval mode is re-applied on every call.
+        self.set_frozen_modules_to_eval_mode()
+        keys_to_use = [
+            "input_ids",
+            "attention_mask",
+            "pixel_values",
+            "image_grid_thw",
+            "mm_token_type_ids",
+        ]
+        vl_input = {key: vl_input[key] for key in keys_to_use if key in vl_input}
+        # GR00T was trained against the **pre-final-RMSNorm** output of the last
+        # decoder layer (this was `outer.hidden_states[-1]` under HF 4.57.x semantics,
+        # where the @check_model_inputs tie-logic skipped overwriting because
+        # Qwen3VLCausalLMOutputWithPast has no `last_hidden_state` field).
+        # HF >=5.x changed the output_hidden_states tuple plumbing so that
+        # `hidden_states[-1]` is post-norm, which is a different tensor and breaks
+        # the trained action head. Capture the pre-norm tensor via a hook on the
+        # final RMSNorm regardless of transformers version.
+        _captured: list[torch.Tensor] = []
+
+        def _pre_norm_hook(_module, args, _out):
+            # args[0] is the tensor handed to the norm, i.e. the pre-norm states.
+            _captured.append(args[0])
+
+        norm = self.model.model.language_model.norm
+        handle = norm.register_forward_hook(_pre_norm_hook)
+        try:
+            # The return value is discarded; only the hooked tensor is used.
+            self.model.model(**vl_input, return_dict=True)
+        finally:
+            # Always detach the hook so it does not accumulate across calls.
+            handle.remove()
+        if not _captured:
+            raise RuntimeError("Failed to capture pre-norm hidden states from Qwen3-VL backbone")
+        backbone_features = _captured[-1]
+        image_mask = vl_input["input_ids"] == self.model.config.image_token_id
+        attention_mask = vl_input["attention_mask"] == 1
+        return BatchFeature(
+            data={
+                "backbone_features": backbone_features,
+                "backbone_attention_mask": attention_mask,
+                "image_mask": image_mask,
+            }
+        )
+
+
+def get_backbone_cls(config: Gr00tN1d7Config):
+    """Return the backbone class matching ``config.model_name``.
+
+    Only names containing ``"nvidia/Cosmos-Reason2"`` are supported; they map to
+    ``_Qwen3VLBackbone``.
+
+    Raises:
+        ValueError: If ``config.model_name`` is not a supported backbone.
+    """
+    if "nvidia/Cosmos-Reason2" not in config.model_name:
+        raise ValueError(f"Unsupported model name: {config.model_name}")
+    return _Qwen3VLBackbone
+
+
+class Gr00tN1d7(PreTrainedModel):
+    """Gr00tN1d7: VLA model with Cosmos-Reason2-2B (Qwen3-VL) backbone.
+
+    Combines a vision-language backbone, a flow-matching action head and the
+    data collator that turns ``vlm_content`` into backbone inputs.
+    Inference goes through :meth:`get_action`.
+    """
+
+    config_class = Gr00tN1d7Config
+    supports_gradient_checkpointing = True
+    _tp_plan = {}
+
+    @property
+    def all_tied_weights_keys(self) -> dict[str, Any]:
+        """Report that the model has no tied weights."""
+        return {}
+
+    def __init__(
+        self,
+        config: Gr00tN1d7Config,
+        transformers_loading_kwargs: dict | None = None,
+    ):
+        """Build backbone, action head and collator from ``config``.
+
+        Args:
+            config: Model configuration.
+            transformers_loading_kwargs: Kwargs forwarded to the backbone and
+                the collator; defaults to ``{"trust_remote_code": True}``.
+        """
+        super().__init__(config)
+        if transformers_loading_kwargs is None:
+            transformers_loading_kwargs = {"trust_remote_code": True}
+        self.config = config
+
+        backbone_cls = get_backbone_cls(config)
+        self.backbone = backbone_cls(
+            model_name=config.model_name,
+            select_layer=config.select_layer,
+            backbone_embedding_dim=config.backbone_embedding_dim,
+            load_bf16=config.load_bf16,
+            transformers_loading_kwargs=transformers_loading_kwargs,
+        )
+
+        # Initialize action head
+        self.action_head = Gr00tN1d7ActionHead(config)
+
+        self.collator = Gr00tN1d7DataCollator(
+            model_name=config.model_name,
+            model_type=config.backbone_model_type,
+            transformers_loading_kwargs=transformers_loading_kwargs,
+        )
+
+    def prepare_input(self, inputs: dict) -> tuple[BatchFeature, BatchFeature]:
+        """Prepare inputs for backbone and action head.
+
+        NOTE: ``inputs`` is modified in place — a ``"vlm_content"`` entry is
+        removed and replaced by the collator's output. Both returned
+        ``BatchFeature`` objects are built from that same dict and then moved
+        to this model's device/dtype.
+        """
+
+        # NOTE -- currently the eval code doesn't use collator, so we need to add it here
+        if "vlm_content" in inputs:
+            # Fix for n_envs > 1: Process all environments' VLM content, not just the first
+            vlm_content_list = inputs["vlm_content"]
+            # Ensure vlm_content_list is always a list for consistent processing
+            if not isinstance(vlm_content_list, list):
+                vlm_content_list = [vlm_content_list]
+
+            # Process all VLM contents through the collator
+            prep = self.collator([{"vlm_content": vlm} for vlm in vlm_content_list])["inputs"]
+            inputs.pop("vlm_content")
+            inputs.update(prep)
+
+        backbone_inputs = self.backbone.prepare_input(inputs)
+        action_inputs = self.action_head.prepare_input(inputs)
+
+        backbone_inputs = backbone_inputs.to(device=self.device, dtype=self.dtype)
+        action_inputs = action_inputs.to(device=self.device, dtype=self.dtype)
+
+        return backbone_inputs, action_inputs
+
+    def forward(self, inputs: dict) -> BatchFeature:
+        """
+        Forward pass through the complete model.
+
+        NOTE(review): this calls ``self.action_head(...)``, i.e. the action
+        head's ``forward``. Confirm that one is defined; inference uses
+        :meth:`get_action` instead.
+
+        Args:
+            inputs: Dictionary containing:
+                - Action inputs (state, action, embodiment_id, etc.)
+
+        Returns:
+            Whatever the action head's ``forward`` returns.
+        """
+        # Prepare inputs for backbone and action head
+        backbone_inputs, action_inputs = self.prepare_input(inputs)
+        backbone_outputs = self.backbone(backbone_inputs)
+        action_outputs = self.action_head(backbone_outputs, action_inputs)
+
+        return action_outputs
+
+    # The action head's get_action already runs under no_grad; disabling
+    # autograd for the whole entry point keeps the backbone pass from
+    # retaining its graph and activations during inference as well.
+    @torch.no_grad()
+    def get_action(self, inputs: dict, options: dict[str, Any] | None = None) -> BatchFeature:
+        """
+        Generate actions using the complete model.
+
+        Runs with gradients disabled.
+
+        Args:
+            inputs: Raw model inputs; see :meth:`prepare_input` (mutated in place).
+            options: Optional settings forwarded to the action head.
+
+        Returns:
+            The action head's ``get_action`` output.
+        """
+        # Prepare inputs for backbone and action head
+        backbone_inputs, action_inputs = self.prepare_input(inputs)
+
+        # Forward through backbone
+        backbone_outputs = self.backbone(backbone_inputs)
+        action_outputs = self.action_head.get_action(backbone_outputs, action_inputs, options)
+
+        return action_outputs
+
+    @property
+    def device(self):
+        """Device of the first parameter of the model."""
+        return next(iter(self.parameters())).device
+
+    @property
+    def dtype(self):
+        """Dtype of the first parameter of the model."""
+        return next(iter(self.parameters())).dtype
+
+
+# Register the model with HuggingFace: bind the "Gr00tN1d7" model-type string to
+# Gr00tN1d7Config, and Gr00tN1d7Config to the Gr00tN1d7 model class. These
+# calls run at module import time.
+AutoConfig.register("Gr00tN1d7", Gr00tN1d7Config)
+AutoModel.register(Gr00tN1d7Config, Gr00tN1d7)
diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py
new file mode 100644
index 00000000000..208f01a7cb5
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py
new file mode 100644
index 00000000000..54fbe770d0d
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/dit.py
@@ -0,0 +1,344 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+import torch.nn.functional as F
+from diffusers import ConfigMixin, ModelMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
+from torch import nn
+
+
+class AdaLayerNorm(nn.Module):
+    """LayerNorm whose scale and shift are predicted from a conditioning embedding."""
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-5,
+        chunk_dim: int = 0,
+    ):
+        """Create the modulation projection and the underlying LayerNorm.
+
+        Args:
+            embedding_dim: Width of both the input and the conditioning vector.
+            norm_elementwise_affine: Whether the LayerNorm has learnable affine terms.
+            norm_eps: Epsilon of the LayerNorm.
+            chunk_dim: Stored on the module; not used by ``forward``.
+        """
+        super().__init__()
+        self.chunk_dim = chunk_dim
+        self.silu = nn.SiLU()
+        # One projection yields both halves (scale, shift) of the modulation.
+        self.linear = nn.Linear(embedding_dim, 2 * embedding_dim)
+        self.norm = nn.LayerNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        temb: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Return ``norm(x) * (1 + scale) + shift`` with scale/shift derived from ``temb``."""
+        modulation = self.linear(self.silu(temb))
+        scale, shift = torch.chunk(modulation, 2, dim=1)
+        normed = self.norm(x)
+        # Broadcast the per-sample modulation over the sequence axis.
+        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class BasicTransformerBlock(nn.Module):
+    """Pre-norm transformer block: one attention layer, then a feed-forward layer.
+
+    Each sub-layer is preceded by a norm and wrapped in a residual connection.
+    The attention layer is built with ``cross_attention_dim``, and ``forward``
+    hands it ``encoder_hidden_states`` when provided.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: int | None = None,
+        activation_fn: str = "geglu",
+        attention_bias: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        norm_eps: float = 1e-5,
+        final_dropout: bool = False,
+        attention_type: str = "default",
+        positional_embeddings: str | None = None,
+        num_positional_embeddings: int | None = None,
+        ff_inner_dim: int | None = None,
+        ff_bias: bool = True,
+        attention_out_bias: bool = True,
+    ):
+        """Create norms, attention, feed-forward and optional dropout.
+
+        Raises:
+            ValueError: If ``positional_embeddings`` is set without
+                ``num_positional_embeddings``.
+        """
+        super().__init__()
+        self.norm_type = norm_type
+
+        if positional_embeddings and num_positional_embeddings is None:
+            raise ValueError(
+                "If `positional_embedding` type is defined, `num_position_embeddings` must also be defined."
+            )
+
+        # Only the sinusoidal variant is materialised; anything else disables it.
+        self.pos_embed = (
+            SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+            if positional_embeddings == "sinusoidal"
+            else None
+        )
+
+        # "ada_norm" conditions the first norm on the timestep embedding.
+        self.norm1 = (
+            AdaLayerNorm(dim)
+            if norm_type == "ada_norm"
+            else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+        )
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim,
+            upcast_attention=upcast_attention,
+            out_bias=attention_out_bias,
+        )
+
+        self.norm3 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+        self.final_dropout = nn.Dropout(dropout) if final_dropout else None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        encoder_hidden_states: torch.Tensor | None = None,
+        encoder_attention_mask: torch.Tensor | None = None,
+        temb: torch.LongTensor | None = None,
+    ) -> torch.Tensor:
+        """Apply attention and feed-forward with residual connections.
+
+        ``encoder_attention_mask`` is used when ``encoder_hidden_states`` is
+        given, otherwise ``attention_mask``. ``temb`` is only consumed when the
+        block was built with ``norm_type="ada_norm"``.
+        """
+        uses_ada_norm = self.norm_type == "ada_norm"
+        normed = self.norm1(hidden_states, temb) if uses_ada_norm else self.norm1(hidden_states)
+
+        if self.pos_embed is not None:
+            normed = self.pos_embed(normed)
+
+        # Choose the mask that matches the key/value source.
+        if encoder_hidden_states is None:
+            active_mask = attention_mask
+        else:
+            active_mask = encoder_attention_mask
+
+        attn_out = self.attn1(
+            normed,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=active_mask,
+        )
+        if self.final_dropout:
+            attn_out = self.final_dropout(attn_out)
+
+        # Residual around attention.
+        hidden_states = attn_out + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        # Residual around the feed-forward layer.
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+        return hidden_states
+
+
+class TimestepEncoder(nn.Module):
+    """Embed timesteps: 256-channel ``Timesteps`` projection, then ``TimestepEmbedding``."""
+
+    def __init__(self, embedding_dim: int):
+        """Create the projection and the embedder that outputs ``embedding_dim`` features."""
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+        """Return the embedding of ``timesteps``, computed in the module's parameter dtype."""
+        param_dtype = next(self.parameters()).dtype
+        projected = self.time_proj(timesteps)
+        # Cast the projection to the parameter dtype before the embedder.
+        return self.timestep_embedder(projected.to(param_dtype))
+
+
+class DiT(ModelMixin, ConfigMixin):
+    """Stack of transformer blocks conditioned on a timestep embedding.
+
+    With ``interleave_self_attention`` enabled, odd-indexed blocks are built
+    without a cross-attention dimension and are called without encoder states;
+    all other blocks receive ``encoder_hidden_states``.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 8,
+        attention_head_dim: int = 64,
+        output_dim: int = 26,
+        num_layers: int = 12,
+        dropout: float = 0.1,
+        attention_bias: bool = True,
+        activation_fn: str = "gelu-approximate",
+        num_embeds_ada_norm: int | None = 1000,
+        upcast_attention: bool = False,
+        norm_type: str = "ada_norm",
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-5,
+        max_num_positional_embeddings: int = 512,
+        compute_dtype=torch.float32,
+        final_dropout: bool = True,
+        positional_embeddings: str | None = "sinusoidal",
+        interleave_self_attention=False,
+        cross_attention_dim: int | None = None,
+    ):
+        """Build the timestep encoder, ``num_layers`` blocks and the output projection."""
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        self.inner_dim = inner_dim
+        self.gradient_checkpointing = False
+        self.timestep_encoder = TimestepEncoder(embedding_dim=inner_dim)
+
+        def build_block(layer_idx: int) -> BasicTransformerBlock:
+            # Odd layers become self-attention blocks when interleaving is on.
+            self_attn_only = interleave_self_attention and layer_idx % 2 == 1
+            return BasicTransformerBlock(
+                inner_dim,
+                num_attention_heads,
+                attention_head_dim,
+                dropout=dropout,
+                activation_fn=activation_fn,
+                attention_bias=attention_bias,
+                upcast_attention=upcast_attention,
+                norm_type=norm_type,
+                norm_elementwise_affine=norm_elementwise_affine,
+                norm_eps=norm_eps,
+                positional_embeddings=positional_embeddings,
+                num_positional_embeddings=max_num_positional_embeddings,
+                final_dropout=final_dropout,
+                cross_attention_dim=None if self_attn_only else cross_attention_dim,
+            )
+
+        self.transformer_blocks = nn.ModuleList([build_block(layer_idx) for layer_idx in range(num_layers)])
+        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+        self.proj_out_2 = nn.Linear(inner_dim, output_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor | None = None,
+        encoder_attention_mask: torch.Tensor | None = None,
+        return_all_hidden_states: bool = False,
+    ):
+        """Run all blocks and project to ``output_dim``.
+
+        Returns:
+            The projected output, or ``(output, all_hidden_states)`` when
+            ``return_all_hidden_states`` is true; the list holds the input
+            followed by every block's output.
+        """
+        temb = self.timestep_encoder(timestep)
+        hidden_states = hidden_states.contiguous()
+        encoder_hidden_states = encoder_hidden_states.contiguous()
+        all_hidden_states = [hidden_states]
+
+        interleave = self.config.interleave_self_attention
+        for idx, block in enumerate(self.transformer_blocks):
+            block_kwargs = {"temb": temb}
+            if not (interleave and idx % 2 == 1):
+                # Cross-attention block: provide the encoder states and mask.
+                block_kwargs["encoder_hidden_states"] = encoder_hidden_states
+                block_kwargs["encoder_attention_mask"] = encoder_attention_mask
+            hidden_states = block(hidden_states, **block_kwargs)
+            all_hidden_states.append(hidden_states)
+
+        # Timestep-conditioned shift/scale applied to the final norm.
+        shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+        output = self.proj_out_2(hidden_states)
+        if return_all_hidden_states:
+            return output, all_hidden_states
+        return output
+
+
+class AlternateVLDiT(DiT):
+    """``DiT`` variant whose cross-attention blocks alternate between two token masks.
+
+    Odd-indexed blocks run without encoder states. Even-indexed blocks attend to
+    the encoder states under the non-image mask when
+    ``idx % (2 * attend_text_every_n_blocks) == 0`` and under the image mask otherwise.
+    """
+
+    def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
+        """Forward all other arguments to ``DiT`` and store the text-attention period."""
+        super().__init__(*args, **kwargs)
+        self.attend_text_every_n_blocks = attend_text_every_n_blocks
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor | None = None,
+        encoder_attention_mask: torch.Tensor | None = None,
+        return_all_hidden_states: bool = False,
+        image_mask: torch.Tensor | None = None,
+        backbone_attention_mask: torch.Tensor | None = None,
+    ):
+        """Run the blocks with alternating masks and project the result.
+
+        ``encoder_attention_mask`` is accepted for signature compatibility but
+        not used here; the masks are derived from ``image_mask`` and
+        ``backbone_attention_mask``.
+
+        Returns:
+            The projected output, or ``(output, all_hidden_states)`` when
+            ``return_all_hidden_states`` is true.
+        """
+        assert image_mask is not None, "Image mask is required"
+        assert self.config.interleave_self_attention, "Interleave self attention must be enabled"
+
+        temb = self.timestep_encoder(timestep)
+        hidden_states = hidden_states.contiguous()
+        encoder_hidden_states = encoder_hidden_states.contiguous()
+
+        # Split the valid encoder tokens into image and non-image subsets.
+        image_attention_mask = image_mask & backbone_attention_mask
+        non_image_attention_mask = (~image_mask) & backbone_attention_mask
+
+        text_period = 2 * self.attend_text_every_n_blocks
+        all_hidden_states = [hidden_states]
+        for idx, block in enumerate(self.transformer_blocks):
+            if idx % 2 == 1:
+                hidden_states = block(hidden_states, temb=temb)
+            else:
+                attends_non_image = idx % text_period == 0
+                selected_mask = non_image_attention_mask if attends_non_image else image_attention_mask
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=selected_mask,
+                    temb=temb,
+                )
+            all_hidden_states.append(hidden_states)
+
+        # Timestep-conditioned shift/scale applied to the final norm.
+        shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+        output = self.proj_out_2(hidden_states)
+        if return_all_hidden_states:
+            return output, all_hidden_states
+        return output
+
+
+class SelfAttentionTransformer(ModelMixin, ConfigMixin):
+    """Plain stack of ``BasicTransformerBlock`` layers called without encoder states."""
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 8,
+        attention_head_dim: int = 64,
+        output_dim: int = 26,
+        num_layers: int = 12,
+        dropout: float = 0.1,
+        attention_bias: bool = True,
+        activation_fn: str = "gelu-approximate",
+        num_embeds_ada_norm: int | None = 1000,
+        upcast_attention: bool = False,
+        max_num_positional_embeddings: int = 512,
+        compute_dtype=torch.float32,
+        final_dropout: bool = True,
+        positional_embeddings: str | None = "sinusoidal",
+        interleave_self_attention=False,
+    ):
+        """Build ``num_layers`` identical blocks of width ``num_attention_heads * attention_head_dim``.
+
+        ``output_dim``, ``num_embeds_ada_norm``, ``compute_dtype`` and
+        ``interleave_self_attention`` are accepted but not read in this body.
+        """
+        super().__init__()
+
+        self.inner_dim = num_attention_heads * attention_head_dim
+        self.gradient_checkpointing = False
+
+        # Every layer is constructed with the same settings.
+        shared_kwargs = {
+            "dropout": dropout,
+            "activation_fn": activation_fn,
+            "attention_bias": attention_bias,
+            "upcast_attention": upcast_attention,
+            "positional_embeddings": positional_embeddings,
+            "num_positional_embeddings": max_num_positional_embeddings,
+            "final_dropout": final_dropout,
+        }
+        layers = []
+        for _ in range(num_layers):
+            layers.append(
+                BasicTransformerBlock(self.inner_dim, num_attention_heads, attention_head_dim, **shared_kwargs)
+            )
+        self.transformer_blocks = nn.ModuleList(layers)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        return_all_hidden_states: bool = False,
+    ):
+        """Apply every block in order.
+
+        Returns:
+            The final hidden states, or ``(final, all_hidden_states)`` when
+            ``return_all_hidden_states`` is true; the list holds the input
+            followed by every block's output.
+        """
+        current = hidden_states.contiguous()
+        trace = [current]
+        for block in self.transformer_blocks:
+            current = block(current)
+            trace.append(current)
+        if return_all_hidden_states:
+            return current, trace
+        return current
diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py
new file mode 100644
index 00000000000..17f0ed661f7
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/embodiment_conditioned_mlp.py
@@ -0,0 +1,169 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from vllm_omni.diffusion.utils.flow_matching import SinusoidalPositionalEncoding, swish
+
+
+class CategorySpecificLinear(nn.Module):
+    """Linear layer that keeps one weight matrix and bias per category (embodiment)."""
+
+    def __init__(self, num_categories, input_dim, hidden_dim):
+        """Allocate ``W`` of shape [num_categories, input_dim, hidden_dim] and a zero bias ``b``."""
+        super().__init__()
+        self.num_categories = num_categories
+        # Weights start as small Gaussian noise, biases at zero.
+        self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
+        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
+
+    def forward(self, x, cat_ids):
+        """
+        Apply each sample's category-specific affine map.
+
+        Args:
+            x: [B, T, input_dim] input tensor
+            cat_ids: [B] category/embodiment IDs
+        Returns:
+            [B, T, hidden_dim] output tensor
+        """
+        per_sample_w = self.W[cat_ids]
+        per_sample_b = self.b[cat_ids].unsqueeze(1)
+        return torch.bmm(x, per_sample_w) + per_sample_b
+
+    def expand_action_dimension(self, old_action_dim, new_action_dim, expand_input=False, expand_output=False):
+        """
+        Grow the layer along its action-sized axes by tiling the existing values.
+
+        An axis is only touched when its flag is set AND its current size equals
+        ``old_action_dim``. The bias is tiled together with the output axis.
+
+        Args:
+            old_action_dim: Original action dimension
+            new_action_dim: New (larger) action dimension
+            expand_input: Whether to expand input dimension (dim=1)
+            expand_output: Whether to expand output dimension (dim=2)
+
+        Raises:
+            ValueError: If ``new_action_dim`` is not larger than ``old_action_dim``.
+        """
+        if new_action_dim <= old_action_dim:
+            raise ValueError(f"New action dim {new_action_dim} must be larger than old action dim {old_action_dim}")
+
+        def tiled(tensor, dim):
+            # Whole copies first, then a leading slice to fill the remainder.
+            copies, leftover = divmod(new_action_dim, old_action_dim)
+            pieces = [tensor] * copies
+            if leftover > 0:
+                pieces.append(tensor.narrow(dim, 0, leftover))
+            return nn.Parameter(torch.cat(pieces, dim=dim))
+
+        if expand_input and self.W.shape[1] == old_action_dim:
+            self.W = tiled(self.W, 1)
+
+        if expand_output and self.W.shape[2] == old_action_dim:
+            self.W = tiled(self.W, 2)
+            # The bias follows the output axis.
+            if self.b.shape[1] == old_action_dim:
+                self.b = tiled(self.b, 1)
+
+
+class CategorySpecificMLP(nn.Module):
+    """Two category-specific linear layers with a ReLU in between."""
+
+    def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
+        """Create ``layer1`` (input -> hidden) and ``layer2`` (hidden -> output)."""
+        super().__init__()
+        self.num_categories = num_categories
+        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
+        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
+
+    def forward(self, x, cat_ids):
+        """
+        Run the MLP with the weights selected by ``cat_ids``.
+
+        Args:
+            x: [B, T, input_dim] input tensor
+            cat_ids: [B] category/embodiment IDs
+        Returns:
+            [B, T, output_dim] output tensor
+        """
+        pre_activation = self.layer1(x, cat_ids)
+        activated = F.relu(pre_activation)
+        return self.layer2(activated, cat_ids)
+
+    def expand_action_dimension(self, old_action_dim, new_action_dim):
+        """
+        Grow the output (action) axis of the last layer by tiling its weights.
+
+        Args:
+            old_action_dim: Original action dimension
+            new_action_dim: New (larger) action dimension
+        """
+        # Only layer2 produces action-sized outputs; layer1 is left untouched.
+        self.layer2.expand_action_dimension(
+            old_action_dim,
+            new_action_dim,
+            expand_input=False,
+            expand_output=True,
+        )
+
+
+class MultiEmbodimentActionEncoder(nn.Module):
+    """Encode noisy actions together with their timestep, using per-embodiment weights."""
+
+    def __init__(self, action_dim, hidden_size, num_embodiments):
+        """Create the three category-specific projections and the timestep encoding."""
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_embodiments = num_embodiments
+
+        # W1 lifts actions to the hidden width, W2 fuses the action and time
+        # embeddings (hence the doubled input width), W3 is the output projection.
+        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
+        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
+        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
+        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
+
+    def forward(self, actions, timesteps, cat_ids):
+        """
+        Args:
+            actions: [B, T, action_dim] action tensor
+            timesteps: [B,] timesteps - a single scalar per batch item
+            cat_ids: [B,] category/embodiment IDs
+        Returns:
+            [B, T, hidden_size] encoded action features
+        Raises:
+            ValueError: If ``timesteps`` is not a 1-D tensor of length B.
+        """
+        batch, horizon, _ = actions.shape
+
+        if timesteps.dim() != 1 or timesteps.shape[0] != batch:
+            raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
+        # Broadcast each sample's scalar timestep over its T action steps: (B,) -> (B, T).
+        per_step_timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
+
+        # Action embedding, (B, T, w).
+        action_emb = self.W1(actions, cat_ids)
+
+        # Timestep encoding, cast to the action embedding's dtype, (B, T, w).
+        time_emb = self.pos_encoding(per_step_timesteps).to(dtype=action_emb.dtype)
+
+        # Fuse along the feature axis, (B, T, 2w) -> (B, T, w), then swish.
+        fused = torch.cat([action_emb, time_emb], dim=-1)
+        fused = swish(self.W2(fused, cat_ids))
+
+        # Output projection, (B, T, w).
+        return self.W3(fused, cat_ids)
+
+    def expand_action_dimension(self, old_action_dim, new_action_dim):
+        """
+        Grow the input (action) axis of ``W1`` by tiling its weights.
+
+        Args:
+            old_action_dim: Original action dimension
+            new_action_dim: New (larger) action dimension
+        """
+        # W1 is the only projection whose input is action-sized.
+        self.W1.expand_action_dimension(
+            old_action_dim,
+            new_action_dim,
+            expand_input=True,
+            expand_output=False,
+        )
diff --git a/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py b/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py
new file mode 100644
index 00000000000..86b6c4b344b
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/modeling/modules/flowmatching_modules.py
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+from torch import nn
+
+from vllm_omni.diffusion.utils.flow_matching import SinusoidalPositionalEncoding, swish
+
# Public names of this module: the encoder defined below plus the two helpers
# imported from `vllm_omni.diffusion.utils.flow_matching`.
__all__ = ["ActionEncoder", "SinusoidalPositionalEncoding", "swish"]
+
+
class ActionEncoder(nn.Module):
    """Three-layer MLP that fuses actions with a sinusoidal timestep embedding."""

    def __init__(self, action_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Layer widths with w = hidden_size and d = action_dim:
        #   W1: d -> w, W2: 2w -> w, W3: w -> w
        self.W1 = nn.Linear(action_dim, hidden_size)
        self.W2 = nn.Linear(2 * hidden_size, hidden_size)
        self.W3 = nn.Linear(hidden_size, hidden_size)

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps):
        """Encode actions together with their per-sample timestep.

        Args:
            actions: tensor of shape (B, T, action_dim).
            timesteps: tensor of shape (B,), a single scalar per batch item.

        Returns:
            Tensor of shape (B, T, hidden_size).

        Raises:
            ValueError: If `timesteps` is not 1-D with length B.
        """
        batch_size, horizon, _ = actions.shape

        # Only one scalar time per batch item is supported.
        if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
            raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")

        # Replicate each sample's time across the horizon: (B,) -> (B, T).
        per_step_time = timesteps.unsqueeze(1).expand(-1, horizon)

        # Action projection and sinusoidal time embedding, both (B, T, w).
        action_features = self.W1(actions)
        time_features = self.pos_encoding(per_step_time).to(dtype=action_features.dtype)

        # Fuse on the feature axis (B, T, 2w), project back to w, then apply swish.
        fused = torch.cat([action_features, time_features], dim=-1)
        hidden = swish(self.W2(fused))

        # Final projection keeps the width at w.
        return self.W3(hidden)
diff --git a/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py
new file mode 100644
index 00000000000..d26ed38be86
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/modeling/processing_gr00t_n1d7.py
@@ -0,0 +1,659 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+import os
+import re
+import warnings
+from copy import deepcopy
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+import torchvision.transforms.v2 as transforms
+from PIL import Image
+from transformers import AutoProcessor, ProcessorMixin, Qwen3VLProcessor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.utils import cached_file
+from vllm.logger import init_logger
+
+from vllm_omni.diffusion.models.gr00t.configs.embodiment_configs import ModalityConfig
+from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import EmbodimentTag
+from vllm_omni.diffusion.models.gr00t.dataio.state_action.state_action_processor import StateActionProcessor
+from vllm_omni.diffusion.models.gr00t.dataio.utils import parse_modality_configs, to_json_serializable
+
+
class LetterBoxTransform:
    """Zero-pad an image tensor to a square, centring it along the shorter side."""

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Return `img` padded to (max(h, w), max(h, w)); square input is returned as-is.

        The last three axes are read as (c, h, w); any leading axes are
        flattened into one batch axis for padding and restored afterwards.
        """
        import math

        *batch_dims, channels, height, width = img.shape
        if height == width:
            return img

        side = max(height, width)
        extra_rows = side - height
        extra_cols = side - width
        top = extra_rows // 2
        left = extra_cols // 2
        # Order is [left, top, right, bottom]; an odd remainder goes to the right/bottom.
        padding = [left, top, extra_cols - left, extra_rows - top]

        if not batch_dims:
            return transforms.functional.pad(img, padding=padding, fill=0)

        flat = img.reshape(math.prod(batch_dims), channels, height, width)
        flat_padded = transforms.functional.pad(flat, padding=padding, fill=0)
        return flat_padded.reshape(batch_dims + [channels, side, side])
+
+
def _build_eval_image_transform(
    image_target_size: list[int],
    image_crop_size: list[int],
) -> transforms.Compose:
    """Build the deterministic eval/inference image pipeline.

    Steps, in order: convert to image, letterbox to square, resize to
    `image_target_size`, center-crop to `image_crop_size`, then resize back to
    `image_target_size`.
    """
    steps = [
        transforms.ToImage(),
        LetterBoxTransform(),
        transforms.Resize(size=image_target_size),
        transforms.CenterCrop(size=image_crop_size),
        transforms.Resize(size=image_target_size),
    ]
    return transforms.Compose(steps)
+
+
# Module-level logger.
logger = init_logger(__name__)

# Suppress protobuf deprecation warnings
warnings.filterwarnings("ignore", category=DeprecationWarning, module="google.protobuf")

# Default mapping from embodiment tag value to the integer embodiment/projector id.
# Several tags deliberately share one id (e.g. every `real_r1_pro_sharpa_*` tag maps to 26).
EMBODIMENT_TAG_TO_PROJECTOR_INDEX = {
    # Pretrain embodiment ids
    "oxe_droid_relative_eef_relative_joint": 24,
    "xdof_relative_eef_relative_joint": 27,
    "xdof_relative_eef_relative_joint_subtask": 27,
    "real_g1_relative_eef_relative_joints": 25,
    "real_r1_pro_sharpa_relative_eef": 26,
    "real_r1_pro_sharpa_relative_eef_human": 26,
    "real_r1_pro_sharpa_relative_eef_maxinsights": 26,
    "real_r1_pro_sharpa_relative_eef_mecka": 26,
    # Posttrain embodiment ids
    "unitree_g1_full_body_with_waist_height_nav_cmd": 25,
    "unitree_g1_sonic": 11,
    "simpler_env_google": 0,
    "simpler_env_widowx": 1,
    "libero_sim": 2,
    "new_embodiment": 10,
}

# Processor repo substituted by `build_processor` when the requested model ships none.
QWEN3_VL_2B_PROCESSOR = "Qwen/Qwen3-VL-2B-Instruct"
+
+
def build_processor(model_name: str, transformers_loading_kwargs: dict) -> Qwen3VLProcessor:
    """Load the Qwen3-VL processor for `model_name`.

    `nvidia/Cosmos-Reason2-2B` is redirected to `QWEN3_VL_2B_PROCESSOR` (with a
    one-time warning); every other name is loaded as given.
    `transformers_loading_kwargs` is forwarded to `from_pretrained`.
    """
    needs_fallback = model_name == "nvidia/Cosmos-Reason2-2B"
    if needs_fallback:
        # Cosmos-Reason2-2B lacks a Qwen3VLProcessor; fall back to upstream artifacts.
        logger.warning_once(
            "Substituting processor from %s because %s does not ship one. "
            "If you fine-tune Cosmos-Reason2-2B's tokenizer/image processor, "
            "load the processor explicitly instead of relying on this fallback.",
            QWEN3_VL_2B_PROCESSOR,
            model_name,
        )
    source = QWEN3_VL_2B_PROCESSOR if needs_fallback else model_name
    return Qwen3VLProcessor.from_pretrained(source, **transformers_loading_kwargs)
+
+
class Gr00tN1d7DataCollator:
    """Collate per-sample feature dicts into one batch.

    The `vlm_content` entries are tokenized together by the Qwen3-VL processor;
    every other key is stacked with NumPy and converted to a tensor.
    """

    def __init__(
        self,
        model_name: str,
        model_type: str = "qwen",
        transformers_loading_kwargs: dict | None = None,
    ):
        """Load the processor for `model_name`.

        Args:
            model_name: Name or path passed to `build_processor`.
            model_type: Stored on the instance and reported by `__str__`.
            transformers_loading_kwargs: Extra keyword arguments forwarded to
                the processor loader; `None` means no extra arguments.
        """
        # A literal `{}` default would be one dict shared by every call.
        loading_kwargs = {} if transformers_loading_kwargs is None else transformers_loading_kwargs
        self.processor = build_processor(model_name, loading_kwargs)
        self.processor.tokenizer.padding_side = "left"
        self.model_type = model_type
        self.model_name = model_name

    def __call__(self, features: list[dict[str, Any]]) -> BatchFeature:
        """Batch `features` and return `BatchFeature(data={"inputs": batch})`.

        Raises:
            NotImplementedError: If a sample already carries tokenized VLM keys
                (`pixel_values`, `image_grid_thw`, `attention_mask`, `input_ids`).
        """
        batch = {}
        keys = list(set().union(*(elem.keys() for elem in features)))

        for key in keys:
            # Samples that lack the key are simply left out of that key's batch.
            values = [elem[key] for elem in features if key in elem]
            if key == "vlm_content":
                text_list = [v["text"] for v in values]
                # Images of all samples are passed to the processor as one flat list.
                image_inputs = [img for v in values for img in v["images"]]

                vlm_inputs = self.processor(
                    text=text_list,
                    images=image_inputs,
                    return_tensors="pt",
                    padding=True,
                )
                batch.update(vlm_inputs.items())
            elif key in (
                "pixel_values",
                "image_grid_thw",
                "attention_mask",
                "input_ids",
            ):
                raise NotImplementedError(f"Collating pre-tokenized key {key!r} is not implemented.")
            else:
                batch[key] = torch.from_numpy(np.stack(values))
        return BatchFeature(data={"inputs": batch})

    def __str__(self):
        return f"Gr00tN1d7DataCollator(model_name={self.model_name}, model_type={self.model_type})"
+
+
+class Gr00tN1d7Processor(ProcessorMixin):
+ data_collator_class = Gr00tN1d7DataCollator
+
+ def __init__(
+ self,
+ modality_configs: dict[str, dict[str, ModalityConfig]],
+ statistics: (dict[str, dict[str, dict[str, dict[str, list[float]]]]] | None) = None,
+ use_percentiles: bool = False,
+ clip_outliers: bool = True,
+ image_crop_size: list[int] = None,
+ image_target_size: list[int] = None,
+ shortest_image_edge: int = 256,
+ crop_fraction: float = 0.95,
+ random_rotation_angle: int | None = None,
+ color_jitter_params: dict[str, float] | None = None,
+ formalize_language: bool = True,
+ model_name: str = "nvidia/Cosmos-Reason2-2B",
+ model_type: str = "qwen",
+ max_state_dim: int = 29,
+ max_action_dim: int = 29,
+ max_action_horizon: int = 50,
+ apply_sincos_state_encoding: bool = False,
+ use_relative_action: bool = False,
+ embodiment_id_mapping: dict[str, int] | None = None,
+ transformers_loading_kwargs: dict = {"trust_remote_code": True},
+ exclude_state: bool = False,
+ # Normalization
+ use_mean_std: bool = False,
+ **kwargs, # absorb deprecated training-only keys from saved processor_config.json
+ ):
+ if kwargs:
+ logger.debug("Gr00tN1d7Processor: ignoring unknown keys: %s", list(kwargs))
+ self.modality_configs = parse_modality_configs(modality_configs)
+
+ # Initialize StateActionProcessor for state/action normalization
+ self.state_action_processor = StateActionProcessor(
+ modality_configs=modality_configs,
+ statistics=statistics,
+ use_percentiles=use_percentiles,
+ clip_outliers=clip_outliers,
+ apply_sincos_state_encoding=apply_sincos_state_encoding,
+ use_relative_action=use_relative_action,
+ )
+
+ self.use_percentiles = use_percentiles
+ self.use_mean_std = use_mean_std
+ self.clip_outliers = clip_outliers
+ self.apply_sincos_state_encoding = apply_sincos_state_encoding
+ self.use_relative_action = use_relative_action
+
+ self.exclude_state = exclude_state
+
+ self.formalize_language = formalize_language
+ self.model_name = model_name
+ self.model_type = model_type
+
+ self.max_state_dim = max_state_dim
+ self.max_action_dim = max_action_dim
+ self.max_action_horizon = max_action_horizon
+
+ self.image_crop_size = image_crop_size
+ self.image_target_size = image_target_size
+ self.random_rotation_angle = random_rotation_angle
+ self.color_jitter_params = color_jitter_params
+ self.processor = build_processor(model_name, transformers_loading_kwargs)
+ self.processor.tokenizer.padding_side = "left"
+ self.embodiment_id_mapping = embodiment_id_mapping or EMBODIMENT_TAG_TO_PROJECTOR_INDEX
+ for k, v in EMBODIMENT_TAG_TO_PROJECTOR_INDEX.items():
+ if k not in self.embodiment_id_mapping:
+ self.embodiment_id_mapping[k] = v
+ self.shortest_image_edge = shortest_image_edge
+ self.crop_fraction = crop_fraction
+
+ self.statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]] = {}
+
+ # Eval/inference image transform
+ self.eval_image_transform = _build_eval_image_transform(
+ image_target_size,
+ image_crop_size,
+ )
+ self._collator = self.data_collator_class(
+ model_name=model_name,
+ model_type=model_type,
+ transformers_loading_kwargs=transformers_loading_kwargs,
+ )
+
    @property
    def collator(self):
        """Return the data collator instance created in `__init__`."""
        return self._collator
+
+ def set_statistics(
+ self,
+ statistics: dict[str, dict[str, dict[str, dict[str, list[float]]]]],
+ override: bool = False,
+ ) -> None:
+ """Set dataset statistics for normalization."""
+ for key in statistics:
+ if key not in self.statistics or override:
+ if override:
+ print(f"Overriding statistics for {key}")
+ self.statistics[key] = deepcopy(statistics[key])
+ else:
+ print(f"Embodiment tag {key} already in statistics, skipping updating")
+
+ self.state_action_processor.set_statistics(statistics, override=override)
+
+ # Compute action dimensions for convenience
+ self.action_dim = {}
+ for embodiment_tag in self.state_action_processor.statistics:
+ self.action_dim[embodiment_tag] = self.state_action_processor.get_action_dim(embodiment_tag)
+
+ def decode_action(
+ self,
+ action: np.ndarray,
+ embodiment_tag: EmbodimentTag,
+ state: dict[str, np.ndarray] | None = None,
+ ):
+ """Undo action normalization and convert relative actions to absolute."""
+ # Split concatenated action into joint groups
+ out_dict = {}
+ start_idx = 0
+ joint_groups = self.modality_configs[embodiment_tag.value]["action"].modality_keys
+ action_horizon = len(self.modality_configs[embodiment_tag.value]["action"].delta_indices)
+ for key in joint_groups:
+ joint_dim = self.state_action_processor.norm_params[embodiment_tag.value]["action"][key]["dim"].item()
+ out_dict[key] = action[..., :action_horizon, start_idx : start_idx + joint_dim]
+ start_idx += joint_dim
+
+ # Use StateActionProcessor to unnormalize and convert to absolute
+ return self.state_action_processor.unapply_action(out_dict, embodiment_tag.value, state=state)
+
+ def unapply(
+ self,
+ action: np.ndarray,
+ embodiment_tag: EmbodimentTag,
+ state: dict[str, np.ndarray] | None = None,
+ prev_action: dict[str, np.ndarray] | None = None,
+ ) -> dict[str, np.ndarray]:
+ """Undo action normalization and convert relative to absolute.
+
+ Args:
+ action: Normalized action array of shape (..., action_horizon, action_dim)
+ embodiment_tag: Embodiment tag
+ state: State observations with "state." prefixed keys (for relative actions)
+ prev_action: Unused (kept for API compatibility)
+
+ Returns:
+ Dict mapping "action." to unnormalized (absolute) action arrays.
+ """
+ out_dict = {}
+ start_idx = 0
+ joint_groups = self.modality_configs[embodiment_tag.value]["action"].modality_keys
+ action_horizon = len(self.modality_configs[embodiment_tag.value]["action"].delta_indices)
+ for key in joint_groups:
+ joint_dim = self.state_action_processor.norm_params[embodiment_tag.value]["action"][key]["dim"].item()
+ out_dict[key] = action[..., :action_horizon, start_idx : start_idx + joint_dim]
+ start_idx += joint_dim
+
+ # Strip "state." prefix for StateActionProcessor
+ stripped_state = None
+ if state is not None:
+ stripped_state = {k.replace("state.", ""): v for k, v in state.items()}
+
+ result = self.state_action_processor.unapply_action(out_dict, embodiment_tag.value, state=stripped_state)
+ return {f"action.{key}": value for key, value in result.items()}
+
+ def process_observation(self, observation: dict[str, Any], embodiment_tag: EmbodimentTag):
+ """Process batched observation tensors for inference.
+
+ Args:
+ observation: Dict with keys like "video.", "state.", ""
+ Video values expected as numpy arrays of shape (B, T, H, W, C).
+ embodiment_tag: Embodiment tag identifying the robot configuration.
+
+ Returns:
+ BatchFeature with tokenized VLM inputs, state, embodiment_id, and action_mask.
+ """
+ modality_config = self.modality_configs[embodiment_tag.value]
+ transformed_observation = {}
+
+ # Normalize states
+ state_keys = modality_config["state"].modality_keys
+ state_data = {key: observation[f"state.{key}"] for key in state_keys}
+ exclude_state = self.exclude_state or getattr(modality_config["state"], "exclude_state", False)
+ if exclude_state:
+ normalized_states = torch.cat(
+ [torch.from_numpy(np.zeros_like(state_data[key])) for key in state_keys], dim=-1
+ )
+ else:
+ norm_state_dict = self.state_action_processor.apply_state(
+ state=state_data, embodiment_tag=embodiment_tag.value
+ )
+ normalized_states = torch.cat([torch.from_numpy(norm_state_dict[key]) for key in state_keys], dim=-1)
+
+ assert normalized_states.shape[1] <= self.max_state_dim, (
+ f"State dimension {normalized_states.shape[1]} exceeds max_state_dim {self.max_state_dim}"
+ )
+ padding_shape = (
+ *normalized_states.shape[:-1],
+ self.max_state_dim - normalized_states.shape[-1],
+ )
+ normalized_states = torch.cat([normalized_states, torch.zeros(padding_shape)], dim=-1)
+ transformed_observation["state"] = normalized_states
+
+ image_keys = modality_config["video"].modality_keys
+ images_dict = {view: torch.from_numpy(observation[f"video.{view}"]) for view in image_keys}
+ images = torch.stack([images_dict[view] for view in image_keys], dim=2) # (B, T, V, H, W, C)
+ assert images.ndim == 6
+ B, T, V, img_H, img_W, img_C = images.shape
+
+ images_perm = images.permute(0, 1, 2, 5, 3, 4).reshape(B, T * V, img_C, img_H, img_W) # (B,T*V,C,H,W)
+ transformed_images = self.eval_image_transform(images_perm).numpy()
+
+ language_key = modality_config["language"].modality_keys[0]
+ language = [
+ re.sub(r"[^\w\s]", "", lang.lower()) if self.formalize_language else lang
+ for lang in observation[language_key]
+ ]
+
+ texts, all_images = [], []
+ for i in range(B):
+ vlm_inputs = self._apply_vlm_processing(transformed_images[i], language[i])
+ vc = vlm_inputs["vlm_content"]
+ texts.append(vc["text"])
+ all_images.extend(vc["images"])
+ tokenized = self.processor(text=texts, images=all_images, return_tensors="pt", padding=True)
+ for k, v in tokenized.items():
+ transformed_observation[k] = v
+
+ embodiment_id = torch.ones(B, dtype=torch.int32) * self.embodiment_id_mapping[embodiment_tag.value]
+ transformed_observation["embodiment_id"] = embodiment_id
+
+ action_config = modality_config["action"]
+ action_horizon = len(action_config.delta_indices)
+ assert action_horizon <= self.max_action_horizon, (
+ f"Action horizon {action_horizon} (from delta_indices) exceeds"
+ f" max_action_horizon {self.max_action_horizon}. Increase model config"
+ f" action_horizon to >= {action_horizon}."
+ )
+ action_mask = torch.zeros((B, self.max_action_horizon), dtype=torch.float32)
+ if action_horizon > 0:
+ action_mask[:, :action_horizon] = 1.0
+ transformed_observation["action_mask"] = action_mask
+
+ return BatchFeature(transformed_observation)
+
+ def _apply_vlm_processing(self, images: np.ndarray, language: str) -> BatchFeature:
+ pil_images = [Image.fromarray(np.transpose(v, (1, 2, 0))) for v in images]
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ *[{"type": "image", "image": img} for img in pil_images],
+ {"type": "text", "text": language},
+ ],
+ }
+ ]
+
+ text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
+ return {
+ "vlm_content": {
+ "text": text,
+ "images": pil_images,
+ "conversation": conversation,
+ }
+ }
+
+ def __call__(
+ self,
+ messages: list[dict[str, Any]],
+ ):
+ assert len(messages) == 1
+ content = messages[0]["content"]
+ embodiment_tag = content.embodiment
+ action_data = content.actions
+ state_data = content.states
+
+ norm_state_dict, normalized_actions = self.state_action_processor.apply(
+ state=state_data,
+ action=action_data,
+ embodiment_tag=embodiment_tag.value,
+ )
+
+ if normalized_actions:
+ action_keys = self.modality_configs[embodiment_tag.value]["action"].modality_keys
+ normalized_actions = torch.cat(
+ [torch.from_numpy(normalized_actions[key]) for key in action_keys],
+ dim=-1,
+ ) # (t, d)
+ action_dim = normalized_actions.shape[1]
+ normalized_actions = torch.cat(
+ [
+ normalized_actions,
+ torch.zeros(
+ normalized_actions.shape[0],
+ self.max_action_dim - normalized_actions.shape[1],
+ ),
+ ],
+ dim=-1,
+ ) # (t, max_action_dim)
+ action_horizon = normalized_actions.shape[0]
+ assert action_horizon <= self.max_action_horizon, (
+ f"Action sequence length {action_horizon} exceeds max_action_horizon"
+ f" {self.max_action_horizon}. Increase model config action_horizon to"
+ f" >= {action_horizon}."
+ )
+ normalized_actions = torch.cat(
+ [
+ normalized_actions,
+ torch.zeros(
+ self.max_action_horizon - normalized_actions.shape[0],
+ self.max_action_dim,
+ ),
+ ],
+ dim=0,
+ ) # (max_action_horizon, max_action_dim)
+ action_mask = torch.ones_like(normalized_actions)
+ action_mask[action_horizon:] = 0
+ action_mask[:, action_dim:] = 0
+ else:
+ normalized_actions = None
+ action_mask = None
+
+ state_keys = self.modality_configs[embodiment_tag.value]["state"].modality_keys
+ exclude_state = self.exclude_state or getattr(
+ self.modality_configs[embodiment_tag.value]["state"], "exclude_state", False
+ )
+ if exclude_state:
+ normalized_states = torch.cat(
+ [torch.from_numpy(np.zeros_like(state_data[key])) for key in state_keys], dim=-1
+ )
+ else:
+ normalized_states = torch.cat([torch.from_numpy(norm_state_dict[key]) for key in state_keys], dim=-1)
+ normalized_states = torch.cat(
+ [
+ normalized_states,
+ torch.zeros(
+ normalized_states.shape[0],
+ self.max_state_dim - normalized_states.shape[1],
+ ),
+ ],
+ dim=-1,
+ )
+
+ image_transform = self.eval_image_transform
+ image_keys = self.modality_configs[embodiment_tag.value]["video"].modality_keys
+
+ if self.formalize_language:
+ language = content.text.lower()
+ language = re.sub(r"[^\w\s]", "", language)
+ else:
+ language = content.text
+
+ vlm_inputs = self._get_vlm_inputs(
+ image_keys=image_keys,
+ images=content.images,
+ masks=content.masks,
+ image_transform=image_transform,
+ language=language,
+ )
+
+ transformed_inputs = {
+ "state": normalized_states.to(torch.get_default_dtype()),
+ }
+ if normalized_actions is not None:
+ transformed_inputs["action"] = normalized_actions.to(torch.get_default_dtype())
+ # Add VLM inputs
+ transformed_inputs.update(vlm_inputs)
+ if action_mask is not None:
+ transformed_inputs["action_mask"] = action_mask
+ transformed_inputs["embodiment_id"] = self.embodiment_id_mapping[embodiment_tag.value]
+ return transformed_inputs
+
+ def _get_vlm_inputs(
+ self,
+ image_keys: list[str],
+ images: list[Image.Image],
+ masks: dict[str, list[np.ndarray]] | None,
+ image_transform: transforms.Compose,
+ language: str,
+ ):
+ temporal_stacked_images = {}
+
+ if masks is not None:
+ raise ValueError("Mask-based transforms are not supported at inference.")
+ for view in image_keys:
+ assert view in images, f"{view} not in {images}"
+ temporal_stacked_images[view] = torch.stack([image_transform(img) for img in images[view]]) # (T, C, H, W)
+
+ for k, v in temporal_stacked_images.items():
+ assert isinstance(k, str), f"{k} is not a string"
+ assert isinstance(v, torch.Tensor), f"{v} is not a torch tensor"
+ assert v.ndim == 4, f"{v} is not a 4D tensor"
+ assert v.dtype == torch.uint8, f"{v} is not a uint8 tensor"
+ assert v.shape[1] == 3, f"{v} is not a 3 channel tensor"
+
+ stacked_images = (
+ torch.stack([temporal_stacked_images[view] for view in image_keys], dim=1).flatten(0, 1).numpy()
+ ) # (T*V, C, H, W), processor expects numpy array
+
+ vlm_inputs = self._apply_vlm_processing(stacked_images, language)
+ return vlm_inputs
+
+ def save_pretrained(self, save_directory: str | Path) -> list[Path]:
+ save_directory = Path(save_directory)
+ save_directory.mkdir(parents=True, exist_ok=True)
+ main_config_file = save_directory / "processor_config.json"
+ statistics_file = save_directory / "statistics.json"
+ embodiment_id_file = save_directory / "embodiment_id.json"
+
+ config = {
+ "processor_class": self.__class__.__name__,
+ "processor_kwargs": {
+ "modality_configs": to_json_serializable(self.modality_configs),
+ "image_crop_size": self.image_crop_size,
+ "image_target_size": self.image_target_size,
+ "random_rotation_angle": self.random_rotation_angle,
+ "color_jitter_params": self.color_jitter_params,
+ "shortest_image_edge": self.shortest_image_edge,
+ "crop_fraction": self.crop_fraction,
+ "model_name": self.model_name,
+ "model_type": self.model_type,
+ "formalize_language": self.formalize_language,
+ "max_state_dim": self.max_state_dim,
+ "max_action_dim": self.max_action_dim,
+ "max_action_horizon": self.max_action_horizon,
+ "use_percentiles": self.use_percentiles,
+ "use_mean_std": self.use_mean_std,
+ "clip_outliers": self.clip_outliers,
+ "apply_sincos_state_encoding": self.apply_sincos_state_encoding,
+ "use_relative_action": self.use_relative_action,
+ "exclude_state": self.exclude_state,
+ },
+ }
+ with open(main_config_file, "w") as f:
+ json.dump(config, f, indent=2)
+ with open(statistics_file, "w") as f:
+ json.dump(
+ to_json_serializable(self.state_action_processor.statistics),
+ f,
+ indent=2,
+ )
+ with open(embodiment_id_file, "w") as f:
+ json.dump(self.embodiment_id_mapping, f, indent=2)
+ return [main_config_file, statistics_file, embodiment_id_file]
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: str | Path, **kwargs):
+ transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", {"trust_remote_code": True})
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+ config_file = pretrained_model_name_or_path / "processor_config.json"
+ statistics_file = pretrained_model_name_or_path / "statistics.json"
+ embodiment_id_file = pretrained_model_name_or_path / "embodiment_id.json"
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if not is_local:
+ config_file = Path(cached_file(pretrained_model_name_or_path, "processor_config.json"))
+ statistics_file = Path(cached_file(pretrained_model_name_or_path, "statistics.json"))
+ embodiment_id_file = Path(cached_file(pretrained_model_name_or_path, "embodiment_id.json"))
+
+ with open(config_file) as f:
+ config = json.load(f)
+ with open(statistics_file) as f:
+ statistics = json.load(f)
+ if embodiment_id_file.exists():
+ with open(embodiment_id_file) as f:
+ embodiment_id_mapping = json.load(f)
+ else:
+ embodiment_id_mapping = None
+ processor_kwargs = config["processor_kwargs"]
+ processor_kwargs["statistics"] = statistics
+ processor_kwargs["embodiment_id_mapping"] = embodiment_id_mapping
+
+ # Backfill missing fields from older checkpoints.
+ processor_kwargs.setdefault("model_name", "nvidia/Cosmos-Reason2-2B")
+ processor_kwargs.setdefault("model_type", "qwen")
+ processor_kwargs.setdefault("clip_outliers", True)
+
+ # Directly override other processor kwargs
+ if kwargs:
+ # Override modality configs while keeping pretrained embodiment configs
+ modality_configs = kwargs.pop("modality_configs", {})
+ for embodiment_tag, modality_config in modality_configs.items():
+ processor_kwargs["modality_configs"][embodiment_tag] = modality_config
+ override_keys = [
+ "random_rotation_angle",
+ "color_jitter_params",
+ "use_relative_action",
+ "exclude_state",
+ "use_mean_std",
+ "model_name",
+ "model_type",
+ "max_action_horizon",
+ "max_state_dim",
+ "max_action_dim",
+ ]
+ for key in override_keys:
+ if key in kwargs:
+ override = kwargs.pop(key)
+ if override is not None:
+ processor_kwargs[key] = override
+ return cls(**processor_kwargs, transformers_loading_kwargs=transformers_loading_kwargs)
+
+
+AutoProcessor.register("Gr00tN1d7", Gr00tN1d7Processor)
diff --git a/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py b/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py
new file mode 100644
index 00000000000..63f8e4ce83e
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/pipeline_gr00t.py
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Iterable, Mapping
+from typing import Any
+
+import numpy as np
+import torch
+from torch import nn
+from vllm.logger import init_logger
+
+from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
+from vllm_omni.diffusion.models.gr00t.policy import Gr00tPolicy
+from vllm_omni.diffusion.request import DUMMY_DIFFUSION_REQUEST_ID, OmniDiffusionRequest
+
# Module-level logger.
logger = init_logger(__name__)
+
+
def _to_float32_action_dict(actions: Mapping[str, Any]) -> dict[str, np.ndarray]:
    """Return `actions` with string keys and float32 NumPy array values.

    Raises:
        RuntimeError: If `actions` is empty.
    """
    converted: dict[str, np.ndarray] = {}
    for name, value in actions.items():
        converted[str(name)] = np.asarray(value, dtype=np.float32)
    if len(converted) == 0:
        raise RuntimeError("GR00T policy returned an empty action dict.")
    return converted
+
+
class Gr00tN1d7Pipeline(nn.Module):
    """GR00T N1.7 policy pipeline backed by vLLM-Omni's local GR00T port.

    vLLM-Omni owns the serving integration: OpenPI observations arrive through
    `sampling_params.extra_args["robot_obs"]`, this pipeline runs GR00T policy
    inference, and actions are returned through `DiffusionOutput.multimodal_output`.
    """

    EXTRA_BODY_PARAMS = frozenset({"robot_obs"})

    def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = "") -> None:
        """Read the model settings from `od_config` and load the GR00T policy.

        `embodiment_tag` falls back to the DROID tag when unset or falsy, and
        `strict` defaults to True. CUDA is used when available, otherwise CPU.
        `prefix` is accepted but not used.
        """
        super().__init__()
        model_config = od_config.model_config
        self.model_path = od_config.model
        configured_tag = model_config.get("embodiment_tag")
        self.embodiment_tag = str(configured_tag or "OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT")
        self.strict = bool(model_config.get("strict", True))
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        logger.info("Loading GR00T N1.7 policy from %s with embodiment_tag=%s", self.model_path, self.embodiment_tag)
        self.policy = Gr00tPolicy(
            model_path=self.model_path,
            embodiment_tag=self.embodiment_tag,
            device=self.device,
            strict=self.strict,
        )

    def reset(self) -> dict[str, Any]:
        """Reset the policy; a falsy result from the policy becomes `{}`."""
        outcome = self.policy.reset()
        return outcome or {}

    @property
    def weights_sources(self) -> tuple[Any, ...]:
        """Empty tuple: this pipeline declares no weight sources."""
        return ()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Reject any weights passed in; GR00T weights are loaded by `Gr00tPolicy`.

        Raises:
            RuntimeError: If `weights` yields at least one tensor.
        """
        received = list(weights)
        if not received:
            return set()
        raise RuntimeError(
            f"Gr00tN1d7Pipeline.load_weights received {len(received)} weight tensors; "
            "weights_sources=() should prevent this. GR00T weights are loaded directly by Gr00tPolicy."
        )

    def _dummy_actions(self) -> dict[str, np.ndarray]:
        """Build all-zero actions of shape (1, horizon, dim) for every action key."""
        embodiment_value = self.policy.embodiment_tag.value
        action_config = self.policy.modality_configs["action"]
        horizon = len(action_config.delta_indices)
        norm_params = self.policy.processor.state_action_processor.norm_params[embodiment_value]["action"]

        def _width(key):
            # `dim` may be a 0-d array/tensor or a plain number.
            raw = norm_params[key]["dim"]
            return int(raw.item() if hasattr(raw, "item") else raw)

        return {key: np.zeros((1, horizon, _width(key)), dtype=np.float32) for key in action_config.modality_keys}

    @torch.inference_mode()
    def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput:
        """Run one policy step for `req` and return its actions.

        Without `robot_obs`, the dummy request gets zero actions and any other
        request gets an error output. Malformed observations or a non-dict
        policy result are reported through `DiffusionOutput.error`.
        """
        del kwargs
        extra_args = req.sampling_params.extra_args or {}
        robot_obs = extra_args.get("robot_obs")

        if robot_obs is None:
            if req.request_id != DUMMY_DIFFUSION_REQUEST_ID:
                return DiffusionOutput(
                    error="Gr00tN1d7Pipeline.forward expects sampling_params.extra_args['robot_obs']."
                )
            return DiffusionOutput(multimodal_output={"actions": self._dummy_actions()})

        if not isinstance(robot_obs, Mapping):
            return DiffusionOutput(error=f"robot_obs must be a dict, got {type(robot_obs).__name__}.")

        if extra_args.get("reset"):
            self.reset()

        policy_obs = _normalize_observation(robot_obs, language_key=self.policy.language_key)
        result = self.policy.get_action(policy_obs)
        # The policy may return either the actions or a tuple whose first item is the actions.
        actions = result[0] if isinstance(result, tuple) else result
        if isinstance(actions, Mapping):
            return DiffusionOutput(multimodal_output={"actions": _to_float32_action_dict(actions)})
        return DiffusionOutput(error=f"GR00T policy returned {type(actions).__name__}; expected dict actions.")
+
+
def _normalize_observation(robot_obs: Mapping[str, Any], *, language_key: str) -> dict[str, Any]:
    """Map an incoming observation onto the `video` / `state` / `language` keys.

    `video` is taken from `robot_obs["video"]`, falling back to `"images"`.
    `language` is taken as given; otherwise a non-None `"prompt"` is wrapped as
    `{language_key: [[str(prompt)]]}`. All other keys are dropped.
    """
    obs: dict[str, Any] = {}

    for source_key in ("video", "images"):
        if source_key in robot_obs:
            obs["video"] = robot_obs[source_key]
            break

    if "state" in robot_obs:
        obs["state"] = robot_obs["state"]

    if "language" in robot_obs:
        obs["language"] = robot_obs["language"]
        return obs

    prompt = robot_obs.get("prompt")
    if prompt is not None:
        obs["language"] = {language_key: [[str(prompt)]]}
    return obs
diff --git a/vllm_omni/diffusion/models/gr00t/policy.py b/vllm_omni/diffusion/models/gr00t/policy.py
new file mode 100644
index 00000000000..b37c2eb46ba
--- /dev/null
+++ b/vllm_omni/diffusion/models/gr00t/policy.py
@@ -0,0 +1,437 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from transformers import AutoModel, AutoProcessor
+
+from vllm_omni.diffusion.models.gr00t.dataio.embodiment_tags import FINETUNE_ONLY_TAGS, POSTTRAIN_TAGS, EmbodimentTag
+from vllm_omni.diffusion.models.gr00t.dataio.types import MessageType, ModalityConfig, VLAStepData
+from vllm_omni.diffusion.models.gr00t.modeling.gr00t_n1d7 import Gr00tN1d7
+from vllm_omni.diffusion.models.gr00t.modeling.processing_gr00t_n1d7 import Gr00tN1d7Processor
+
+
+def _rec_to_dtype(value: Any, dtype: torch.dtype) -> Any:
+    """Cast every floating-point tensor inside a nested structure to ``dtype``.
+
+    Tensors that are not floating point are returned untouched. Anything that
+    is a dict or exposes ``items`` is rebuilt as a plain ``dict``; lists are
+    rebuilt as lists. Every other value (tuples included) is returned as-is.
+    """
+    if isinstance(value, torch.Tensor):
+        if not torch.is_floating_point(value):
+            return value
+        return value.to(dtype=dtype)
+
+    looks_like_mapping = isinstance(value, dict) or hasattr(value, "items")
+    if looks_like_mapping:
+        converted = {}
+        for key, item in value.items():
+            converted[key] = _rec_to_dtype(item, dtype)
+        return converted
+
+    if isinstance(value, list):
+        return [_rec_to_dtype(item, dtype) for item in value]
+
+    return value
+
+
+class Gr00tPolicy:
+ """Core policy class for Gr00t model inference.
+
+ This policy handles the end-to-end inference pipeline:
+ 1. Validates input observations
+ 2. Processes observations with pretrained VLA processor
+ 3. Runs model inference
+ 4. Decodes and returns actions
+
+ The policy expects observations with specific modalities (video, state, language)
+ and returns actions in the format defined by the model's modality configuration.
+ """
+
+    def __init__(
+        self,
+        embodiment_tag: EmbodimentTag | str,
+        model_path: str,
+        *,
+        device: int | str,
+        strict: bool = True,
+    ):
+        """Load the model and processor and select the embodiment's modality configs.
+
+        Args:
+            embodiment_tag: Robot/environment tag, either an ``EmbodimentTag`` or a
+                string that is passed through ``EmbodimentTag.resolve``.
+            model_path: Checkpoint directory handed to ``AutoModel.from_pretrained``.
+            device: Device the model is moved to (e.g. 'cuda:0', 0, 'cpu').
+            strict: When true, ``get_action`` validates observations and actions.
+
+        Raises:
+            ValueError: If the processor's modality configs have no entry for the
+                resolved embodiment tag's value.
+            AssertionError: If the language modality lists no keys or does not
+                have exactly one delta index.
+        """
+        self.strict = strict
+        if isinstance(embodiment_tag, str):
+            embodiment_tag = EmbodimentTag.resolve(embodiment_tag)
+        checkpoint_root = Path(model_path)
+
+        # Load the weights, switch to eval mode, then move to the target device in bfloat16.
+        self.model: Gr00tN1d7 = AutoModel.from_pretrained(checkpoint_root)
+        self.model.eval()
+        self.model.to(device=device, dtype=torch.bfloat16)
+
+        # Processor files may sit in a "processor/" subdirectory instead of the
+        # checkpoint root; use the subdirectory only when the root has no
+        # processor_config.json of its own.
+        nested_processor_dir = checkpoint_root / "processor"
+        if nested_processor_dir.is_dir() and not (checkpoint_root / "processor_config.json").exists():
+            processor_dir = nested_processor_dir
+        else:
+            processor_dir = checkpoint_root
+        self.processor: Gr00tN1d7Processor = AutoProcessor.from_pretrained(processor_dir)
+
+        self.embodiment_tag = embodiment_tag
+        all_modality_configs = self.processor.modality_configs
+        if self.embodiment_tag.value not in all_modality_configs:
+            # List what the checkpoint does support, preferring public enum
+            # names over raw checkpoint tag values where a mapping exists.
+            supported_lines = []
+            for tag_value in sorted(all_modality_configs.keys()):
+                enum_name = EmbodimentTag.reverse_lookup(tag_value)
+                if enum_name == tag_value:
+                    supported_lines.append(f" {tag_value:30s} (internal, no public enum)")
+                else:
+                    supported_lines.append(f" {enum_name:30s} (--embodiment-tag {enum_name})")
+            supported_str = "\n".join(supported_lines)
+
+            if self.embodiment_tag in POSTTRAIN_TAGS:
+                hint = (
+                    f"\n\nHint: '{self.embodiment_tag.name}' is a posttrain tag that requires "
+                    f"a finetuned checkpoint, not the base model. "
+                    f"See the example READMEs for how to finetune and download checkpoints."
+                )
+            elif self.embodiment_tag in FINETUNE_ONLY_TAGS:
+                hint = (
+                    f"\n\nHint: '{self.embodiment_tag.name}' is for finetuning custom robots. "
+                    f"Use it with launch_finetune.py, not with the base model directly."
+                )
+            else:
+                hint = ""
+
+            raise ValueError(
+                f"Embodiment tag '{self.embodiment_tag.name}' "
+                f"(value='{self.embodiment_tag.value}') is not supported "
+                f"by this checkpoint.\n\n"
+                f"Supported tags in this checkpoint:\n{supported_str}"
+                f"{hint}"
+            )
+
+        # Keep every modality of the selected embodiment except "rl_info".
+        embodiment_configs = all_modality_configs[self.embodiment_tag.value]
+        self.modality_configs = {k: v for k, v in embodiment_configs.items() if k != "rl_info"}
+        self.collate_fn = self.processor.collator
+
+        # An embodiment may declare several language keys (e.g. paraphrases);
+        # only the first one is used at inference.
+        language_config = self.modality_configs["language"]
+        language_keys = language_config.modality_keys
+        language_delta_indices = language_config.delta_indices
+        assert len(language_keys) >= 1, "At least one language key is required"
+        assert len(language_delta_indices) == 1, "Only one language delta index is supported"
+        self.language_key = language_keys[0]
+
+    def _unbatch_observation(self, value: dict[str, Any]) -> list[dict[str, Any]]:
+        """Split a batched observation into one observation per batch element.
+
+        The batch size is taken from ``shape[0]`` of the first entry in
+        ``value["video"]``; every entry of the ``video``, ``state`` and
+        ``language`` dicts is then indexed with each batch position.
+
+        Args:
+            value: Batched observation with ``video``, ``state`` and ``language`` dicts.
+
+        Returns:
+            A list with one ``{"video", "state", "language"}`` dict per batch element.
+        """
+        videos = value["video"]
+        first_video_key = list(videos.keys())[0]
+        batch_size = videos[first_video_key].shape[0]
+
+        modalities = ("video", "state", "language")
+        return [
+            {modality: {k: v[i] for k, v in value[modality].items()} for modality in modalities}
+            for i in range(batch_size)
+        ]
+
+    def _to_vla_step_data(self, observation: dict[str, Any]) -> VLAStepData:
+        """Wrap one unbatched observation in a ``VLAStepData`` for the processor.
+
+        Args:
+            observation: Single observation with ``video``, ``state`` and ``language`` entries.
+
+        Returns:
+            ``VLAStepData`` carrying the images, the states, an empty action dict
+            and the first element stored under ``self.language_key``.
+        """
+        images = observation["video"]
+        states = observation["state"]
+        instruction = observation["language"][self.language_key][0]
+        return VLAStepData(
+            images=images,
+            states=states,
+            actions={},  # inference has no ground-truth actions to supply
+            text=instruction,
+            embodiment=self.embodiment_tag,
+        )
+
+    def check_observation(self, observation: dict[str, Any]) -> None:
+        """Validate that the observation has the correct structure and types.
+
+        Every key listed in the video, state and language modality configs must
+        be present, and all of them must agree on the batch size.
+
+        Expected observation structure:
+            - video: dict[str, np.ndarray[np.uint8, (B, T, H, W, C)]]
+                - B: batch size
+                - T: temporal horizon (``len(delta_indices)`` of the video config)
+                - H, W: image height and width
+                - C: number of channels (must be 3)
+            - state: dict[str, np.ndarray[np.float32, (B, T, D)]]
+                - B: batch size
+                - T: temporal horizon (``len(delta_indices)`` of the state config)
+                - D: state dimension
+            - language: dict[str, list[list[str]]]
+                - Shape: (B, T) where each element is a string
+                - T: temporal horizon; exactly one instruction is accepted
+
+        The container type of each entry is verified before its length is
+        taken, so a wrongly typed entry fails with ``AssertionError`` rather
+        than a ``TypeError`` from ``len()``.
+
+        Note:
+            The checks are ``assert`` statements and are therefore skipped when
+            Python runs with ``-O``.
+
+        Args:
+            observation: Dictionary containing video, state, and language modalities
+
+        Raises:
+            AssertionError: If any validation check fails
+        """
+        # All three modalities must be present and must be dictionaries.
+        for modality in ["video", "state", "language"]:
+            assert modality in observation, f"Observation must contain a '{modality}' key"
+            assert isinstance(observation[modality], dict), (
+                f"Observation '{modality}' must be a dictionary. "
+                f"Got {type(observation[modality])}: {observation[modality]}"
+            )
+
+        # Batch size shared by every key; -1 means "not seen yet".
+        bs = -1
+
+        for video_key in self.modality_configs["video"].modality_keys:
+            # Presence first: everything below indexes into the observation.
+            assert video_key in observation["video"], f"Video key '{video_key}' must be in observation"
+
+            batched_video = observation["video"][video_key]
+
+            # Type, dtype and rank are checked before len() so that len() is
+            # only ever applied to an array that has a leading axis.
+            assert isinstance(batched_video, np.ndarray), (
+                f"Video key '{video_key}' must be a numpy array. Got {type(batched_video)}"
+            )
+            assert batched_video.dtype == np.uint8, (
+                f"Video key '{video_key}' must be a numpy array of type np.uint8. Got {batched_video.dtype}"
+            )
+            assert batched_video.ndim == 5, (
+                f"Video key '{video_key}' must be a numpy array of shape (B, T, H, W, C), got {batched_video.shape}"
+            )
+
+            # Set or verify batch size consistency across all video keys
+            if bs == -1:
+                bs = len(batched_video)
+            else:
+                assert len(batched_video) == bs, (
+                    f"Video key '{video_key}' must have batch size {bs}. Got {len(observation['video'][video_key])}"
+                )
+
+            # Temporal dimension must match the horizon from the config
+            assert batched_video.shape[1] == len(self.modality_configs["video"].delta_indices), (
+                f"Video key '{video_key}'s horizon must be "
+                f"{len(self.modality_configs['video'].delta_indices)}. Got {batched_video.shape[1]}"
+            )
+
+            # Channel dimension must be 3
+            assert batched_video.shape[-1] == 3, (
+                f"Video key '{video_key}'s channel 'C' must be 3. Got {batched_video.shape[-1]}"
+            )
+
+        for state_key in self.modality_configs["state"].modality_keys:
+            assert state_key in observation["state"], f"State key '{state_key}' must be in observation"
+
+            batched_state = observation["state"][state_key]
+
+            assert isinstance(batched_state, np.ndarray), (
+                f"State key '{state_key}' must be a numpy array. Got {type(batched_state)}"
+            )
+            assert batched_state.dtype == np.float32, (
+                f"State key '{state_key}' must be a numpy array of type np.float32. Got {batched_state.dtype}"
+            )
+            assert batched_state.ndim == 3, (
+                f"State key '{state_key}' must be a numpy array of shape (B, T, D), got {batched_state.shape}"
+            )
+
+            # Set or verify batch size consistency across all state keys
+            if bs == -1:
+                bs = len(batched_state)
+            else:
+                assert len(batched_state) == bs, (
+                    f"State key '{state_key}' must have batch size {bs}. Got {len(observation['state'][state_key])}"
+                )
+
+            # Temporal dimension must match the horizon from the config
+            assert batched_state.shape[1] == len(self.modality_configs["state"].delta_indices), (
+                f"State key '{state_key}'s horizon must be "
+                f"{len(self.modality_configs['state'].delta_indices)}. Got {batched_state.shape[1]}"
+            )
+
+        for language_key in self.modality_configs["language"].modality_keys:
+            assert language_key in observation["language"], f"Language key '{language_key}' must be in observation"
+
+            batched_language: list[list[str]] = observation["language"][language_key]
+
+            # Outer structure is a list (batch dimension); checked before len().
+            assert isinstance(batched_language, list), (
+                f"Language key '{language_key}' must be a list. Got {type(batched_language)}"
+            )
+
+            # Set or verify batch size consistency with the other modalities
+            if bs == -1:
+                bs = len(batched_language)
+            else:
+                assert len(batched_language) == bs, (
+                    f"Language key '{language_key}' must have batch size {bs}. "
+                    f"Got {len(observation['language'][language_key])}"
+                )
+
+            for batch_item in batched_language:
+                # Inner structure is a list (temporal dimension); checked before len().
+                assert isinstance(batch_item, list), f"Language batch item must be a list. Got {type(batch_item)}"
+
+                # Temporal dimension must match the horizon; the message reports
+                # the offending item's own length, not the batch size.
+                assert len(batch_item) == len(self.modality_configs["language"].delta_indices), (
+                    f"Language key '{language_key}'s horizon must be "
+                    f"{len(self.modality_configs['language'].delta_indices)}. Got {len(batch_item)}"
+                )
+
+                # Exactly one language instruction per batch element is supported
+                assert len(batch_item) == 1, f"Language batch item must have exactly one item. Got {len(batch_item)}"
+
+                # The instruction itself must be a string
+                assert isinstance(batch_item[0], str), (
+                    f"Language batch item must be a string. Got {type(batch_item[0])}"
+                )
+
+    def _get_action(
+        self, observation: dict[str, Any], options: dict[str, Any] | None = None
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Compute actions for a batched observation.
+
+        Steps: unbatch the observation, run each sample through the processor,
+        collate the results, cast floating-point tensors to bfloat16, call the
+        model, then decode the prediction with the processor.
+
+        Args:
+            observation: Batched observation dictionary
+            options: Not read by this method
+
+        Returns:
+            Tuple of (actions cast to float32, empty info dict)
+        """
+        states = []
+        processed_inputs = []
+        for obs in self._unbatch_observation(observation):
+            vla_step_data = self._to_vla_step_data(obs)
+            # Per-sample states are kept because decode_action takes them below.
+            states.append(vla_step_data.states)
+            episode_step = {"type": MessageType.EPISODE_STEP.value, "content": vla_step_data}
+            processed_inputs.append(self.processor([episode_step]))
+
+        # One model batch out of the per-sample processor outputs.
+        collated_inputs = _rec_to_dtype(self.collate_fn(processed_inputs), dtype=torch.bfloat16)
+
+        with torch.inference_mode():
+            model_pred = self.model.get_action(**collated_inputs)
+        normalized_action = model_pred["action_pred"].float()
+
+        # Re-batch the states per key by stacking the samples along a new leading axis.
+        state_keys = self.modality_configs["state"].modality_keys
+        batched_states = {k: np.stack([s[k] for s in states], axis=0) for k in state_keys}
+        unnormalized_action = self.processor.decode_action(
+            normalized_action.cpu().numpy(), self.embodiment_tag, batched_states
+        )
+
+        # Every action array is returned as float32.
+        casted_action = {}
+        for key, value in unnormalized_action.items():
+            casted_action[key] = value.astype(np.float32)
+        return casted_action, {}
+
+    def check_action(self, action: dict[str, Any]) -> None:
+        """Validate the action dict against the action modality config.
+
+        Each key listed in the action config must map to a float32 numpy array
+        with three dimensions, (B, T, D), whose second dimension equals
+        ``len(delta_indices)`` of the action config.
+
+        Args:
+            action: Dictionary containing action arrays for each action key
+
+        Raises:
+            AssertionError: If any validation check fails
+        """
+        action_config = self.modality_configs["action"]
+        for action_key in action_config.modality_keys:
+            # Presence comes first; the remaining checks index into the dict.
+            assert action_key in action, f"Action key '{action_key}' must be in action"
+            action_arr = action[action_key]
+
+            # Container type, then element type, then rank.
+            assert isinstance(action_arr, np.ndarray), (
+                f"Action key '{action_key}' must be a numpy array. Got {type(action_arr)}"
+            )
+            assert action_arr.dtype == np.float32, (
+                f"Action key '{action_key}' must be a numpy array of type np.float32. Got {action_arr.dtype}"
+            )
+            assert action_arr.ndim == 3, (
+                f"Action key '{action_key}' must be a numpy array of shape (B, T, D), got {action_arr.shape}"
+            )
+
+            # The horizon axis must match the number of configured delta indices.
+            assert action_arr.shape[1] == len(self.modality_configs["action"].delta_indices), (
+                f"Action key '{action_key}'s horizon must be "
+                f"{len(self.modality_configs['action'].delta_indices)}. Got {action_arr.shape[1]}"
+            )
+
+    def get_action(
+        self, observation: dict[str, Any], options: dict[str, Any] | None = None
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Compute actions, validating input and output when ``self.strict`` is set.
+
+        Args:
+            observation: Batched observation dictionary
+            options: Forwarded unchanged to ``_get_action``
+
+        Returns:
+            Tuple of (actions_dict, info_dict) as produced by ``_get_action``
+        """
+        validate = self.strict
+        if validate:
+            self.check_observation(observation)
+
+        action, info = self._get_action(observation, options)
+
+        if validate:
+            self.check_action(action)
+        return action, info
+
+    def get_modality_config(self) -> dict[str, ModalityConfig]:
+        """Return the modality configs selected for this policy's embodiment.
+
+        The stored dict itself is returned, not a copy.
+        """
+        return self.modality_configs
+
+    def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        """Reset the policy to its initial state.
+
+        This implementation changes nothing on the instance and ignores ``options``.
+
+        Args:
+            options: Dictionary containing the options for the reset (not read here)
+
+        Returns:
+            An empty info dictionary
+        """
+        return {}
diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py
index a381640b507..bbae141116b 100644
--- a/vllm_omni/diffusion/registry.py
+++ b/vllm_omni/diffusion/registry.py
@@ -151,6 +151,11 @@
"pipeline_internvla_a1",
"InternVLAA1Pipeline",
),
+ "Gr00tN1d7Pipeline": (
+ "gr00t",
+ "pipeline_gr00t",
+ "Gr00tN1d7Pipeline",
+ ),
"LongCatImageEditPipeline": (
"longcat_image",
"pipeline_longcat_image_edit",
diff --git a/vllm_omni/diffusion/utils/flow_matching.py b/vllm_omni/diffusion/utils/flow_matching.py
new file mode 100644
index 00000000000..e81d405ec23
--- /dev/null
+++ b/vllm_omni/diffusion/utils/flow_matching.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+
+import torch
+from torch import nn
+
+
+def swish(x: torch.Tensor) -> torch.Tensor:
+    """Return ``x * sigmoid(x)`` computed elementwise."""
+    gate = torch.sigmoid(x)
+    return x * gate
+
+
+class SinusoidalPositionalEncoding(nn.Module):
+    """Sine/cosine encoding of a 2-D tensor of timesteps.
+
+    The last output dimension is ``2 * (embedding_dim // 2)``: the sine half
+    followed by the cosine half. An odd ``embedding_dim`` therefore yields one
+    channel fewer than requested.
+    """
+
+    def __init__(self, embedding_dim: int) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+        """Encode ``timesteps`` of shape (B, T) into (B, T, 2 * (embedding_dim // 2))."""
+        timesteps = timesteps.float()
+        # Unpacking rejects any input that is not exactly two-dimensional.
+        _batch, _steps = timesteps.shape
+
+        half_dim = self.embedding_dim // 2
+        log_step = math.log(10000.0) / half_dim
+        positions = torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
+        inv_freq = torch.exp(-positions * log_step)  # (half_dim,)
+
+        angles = timesteps[..., None] * inv_freq  # (B, T, half_dim)
+        return torch.cat((angles.sin(), angles.cos()), dim=-1)
diff --git a/vllm_omni/model_executor/models/gr00t/__init__.py b/vllm_omni/model_executor/models/gr00t/__init__.py
new file mode 100644
index 00000000000..208f01a7cb5
--- /dev/null
+++ b/vllm_omni/model_executor/models/gr00t/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/vllm_omni/model_executor/models/gr00t/pipeline.py b/vllm_omni/model_executor/models/gr00t/pipeline.py
new file mode 100644
index 00000000000..38a76f29769
--- /dev/null
+++ b/vllm_omni/model_executor/models/gr00t/pipeline.py
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""GR00T N1.7 single-stage policy topology."""
+
+from vllm_omni.config.stage_config import (
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+)
+
+# Single-stage topology for GR00T N1.7: one stage of execution type DIFFUSION
+# that takes no input from other stages and whose result is the pipeline's
+# final output, typed "actions".
+GR00T_N1D7_PIPELINE = PipelineConfig(
+    model_type="Gr00tN1d7",
+    model_arch="Gr00tN1d7Pipeline",
+    hf_architectures=("Gr00tN1d7",),
+    stages=(
+        StagePipelineConfig(
+            stage_id=0,
+            model_stage="diffusion",
+            execution_type=StageExecutionType.DIFFUSION,
+            input_sources=(),  # empty: this stage consumes no other stage's output
+            final_output=True,
+            final_output_type="actions",
+            model_arch="Gr00tN1d7Pipeline",
+        ),
+    ),
+)