diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index bdc7eb33d93..e6bffdf1b9f 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -7,7 +7,7 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i We support both 29 and 23 DOF G1 EDU version. We introduce: - **`unitree g1` robot class, handling low level read/write from/to the humanoid** -- **ZMQ socket bridge** for remote communication over wlan, allowing for remote policy deployment as well as over eth or directly on the Orin +- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot - **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma - **Simulation mode** for testing policies without the physical robot in mujoco @@ -110,7 +110,7 @@ ssh unitree@ # Password: 123 ``` -Replace `` with your robot's actual WiFi IP address (e.g., `172.18.129.215`). +Replace `` with your robot's actual WiFi IP address. --- @@ -188,7 +188,7 @@ Press `Ctrl+C` to stop the policy. ## Running in Simulation Mode (MuJoCo) -You can now test and develop policies without a physical robot using MuJoCo. To do so simply set `is_simulation=True` in config. +You can now test policies in MuJoCo simulation before deploying them on the physical robot. To do so, simply set `is_simulation=True` in the config. 
## Additional Resources diff --git a/examples/unitree_g1/gr00t_locomotion.py b/examples/unitree_g1/gr00t_locomotion.py index 30e5a27e671..0123b520673 100644 --- a/examples/unitree_g1/gr00t_locomotion.py +++ b/examples/unitree_g1/gr00t_locomotion.py @@ -111,34 +111,29 @@ def __init__(self, policy_balance, policy_walk, robot, config): def run_step(self): # Get current observation - robot_state = self.robot.get_observation() + obs = self.robot.get_observation() - if robot_state is None: + if not obs: return # Get command from remote controller - if robot_state.wireless_remote is not None: - self.robot.remote_controller.set(robot_state.wireless_remote) - if self.robot.remote_controller.button[0]: # R1 - raise waist - self.groot_height_cmd += 0.001 - self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - if self.robot.remote_controller.button[4]: # R2 - lower waist - self.groot_height_cmd -= 0.001 - self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - else: - self.robot.remote_controller.lx = 0.0 - self.robot.remote_controller.ly = 0.0 - self.robot.remote_controller.rx = 0.0 - self.robot.remote_controller.ry = 0.0 - - self.cmd[0] = self.robot.remote_controller.ly # Forward/backward - self.cmd[1] = self.robot.remote_controller.lx * -1 # Left/right - self.cmd[2] = self.robot.remote_controller.rx * -1 # Rotation rate - - # Get joint positions and velocities - for i in range(29): - self.groot_qj_all[i] = robot_state.motor_state[i].q - self.groot_dqj_all[i] = robot_state.motor_state[i].dq + if obs["remote.buttons"][0]: # R1 - raise waist + self.groot_height_cmd += 0.001 + self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) + if obs["remote.buttons"][4]: # R2 - lower waist + self.groot_height_cmd -= 0.001 + self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) + + self.cmd[0] = obs["remote.ly"] # Forward/backward + self.cmd[1] = obs["remote.lx"] * -1 # Left/right + self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate + 
+ # Get joint positions and velocities from flat dict + for motor in G1_29_JointIndex: + name = motor.name + idx = motor.value + self.groot_qj_all[idx] = obs[f"{name}.q"] + self.groot_dqj_all[idx] = obs[f"{name}.dq"] # Adapt observation for g1_23dof for idx in MISSING_JOINTS: @@ -150,8 +145,8 @@ def run_step(self): dqj_obs = self.groot_dqj_all.copy() # Express IMU data in gravity frame of reference - quat = robot_state.imu_state.quaternion - ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) + quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]] + ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32) gravity_orientation = self.robot.get_gravity_orientation(quat) # Scale joint positions and velocities before policy inference @@ -219,6 +214,8 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None: config = UnitreeG1Config() robot = UnitreeG1(config) + robot.connect() + # Initialize gr00T locomotion controller groot_controller = GrootLocomotionController( policy_balance=policy_balance, @@ -234,7 +231,7 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None: logger.info("Press Ctrl+C to stop") # Run step - while True: + while not robot._shutdown_event.is_set(): start_time = time.time() groot_controller.run_step() elapsed = time.time() - start_time diff --git a/examples/unitree_g1/holosoma_locomotion.py b/examples/unitree_g1/holosoma_locomotion.py index 017f7835a94..3a07023de66 100644 --- a/examples/unitree_g1/holosoma_locomotion.py +++ b/examples/unitree_g1/holosoma_locomotion.py @@ -126,24 +126,23 @@ def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray): def run_step(self): # Get current observation - robot_state = self.robot.get_observation() + obs = self.robot.get_observation() - if robot_state is None: + if not obs: return # Get command from remote controller - if robot_state.wireless_remote is not None: - 
self.robot.remote_controller.set(robot_state.wireless_remote) - - ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0 - lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0 - rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0 + ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0 + lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0 + rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0 self.cmd[:] = [ly, -lx, -rx] # Get joint positions and velocities - for i in range(29): - self.qj[i] = robot_state.motor_state[i].q - self.dqj[i] = robot_state.motor_state[i].dq + for motor in G1_29_JointIndex: + name = motor.name + idx = motor.value + self.qj[idx] = obs[f"{name}.q"] + self.dqj[idx] = obs[f"{name}.dq"] # Adapt observation for g1_23dof for idx in MISSING_JOINTS: @@ -151,8 +150,8 @@ def run_step(self): self.dqj[idx] = 0.0 # Express IMU data in gravity frame of reference - quat = robot_state.imu_state.quaternion - ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) + quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]] + ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32) gravity = self.robot.get_gravity_orientation(quat) # Scale joint positions and velocities before policy inference @@ -220,6 +219,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") - # Initialize robot config = UnitreeG1Config() robot = UnitreeG1(config) + robot.connect() holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd) @@ -230,7 +230,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") - logger.info("Press Ctrl+C to stop") # Run step - while True: + while not robot._shutdown_event.is_set(): start_time = time.time() holosoma_controller.run_step() elapsed = time.time() - 
start_time diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index 1b2d386d6b3..c0e7b6284ae 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -43,6 +43,11 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s cameras[key] = Reachy2Camera(cfg) + elif cfg.type == "zmq": + from .zmq.camera_zmq import ZMQCamera + + cameras[key] = ZMQCamera(cfg) + else: try: cameras[key] = cast(Camera, make_device_from_device_class(cfg)) diff --git a/src/lerobot/cameras/zmq/__init__.py b/src/lerobot/cameras/zmq/__init__.py new file mode 100644 index 00000000000..d760c5325f6 --- /dev/null +++ b/src/lerobot/cameras/zmq/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .camera_zmq import ZMQCamera +from .configuration_zmq import ZMQCameraConfig + +__all__ = ["ZMQCamera", "ZMQCameraConfig"] diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py new file mode 100644 index 00000000000..1a4155f4bbc --- /dev/null +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ZMQCamera - Captures frames from remote cameras via ZeroMQ using a JSON protocol in the +following format: + { + "timestamps": {"camera_name": float}, + "images": {"camera_name": "<base64-encoded JPEG>"} + } +""" + +import base64 +import json +import logging +import time +from threading import Event, Lock, Thread +from typing import Any + +import cv2 +import numpy as np +from numpy.typing import NDArray + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..camera import Camera +from ..configs import ColorMode +from .configuration_zmq import ZMQCameraConfig + +logger = logging.getLogger(__name__) + + +class ZMQCamera(Camera): + """ + Example usage: + ```python + from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig + + config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera") + camera = ZMQCamera(config) + camera.connect() + frame = camera.read() + camera.disconnect() + ``` + """ + + def __init__(self, config: ZMQCameraConfig): + super().__init__(config) + import zmq + + self.config = config + self.server_address = config.server_address + self.port = config.port + self.camera_name = config.camera_name + self.color_mode = config.color_mode + self.timeout_ms = config.timeout_ms + + self.context: zmq.Context | None = None + self.socket: zmq.Socket | None = None + self._connected = False + + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_lock: Lock = Lock() + self.latest_frame: NDArray[Any] | None = None + self.new_frame_event: Event = Event() + + def 
__str__(self) -> str: + return f"ZMQCamera({self.camera_name}@{self.server_address}:{self.port})" + + @property + def is_connected(self) -> bool: + return self._connected and self.context is not None and self.socket is not None + + def connect(self, warmup: bool = True) -> None: + """Connect to ZMQ camera server.""" + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} is already connected.") + + logger.info(f"Connecting to {self}...") + + try: + import zmq + + self.context = zmq.Context() + self.socket = self.context.socket(zmq.SUB) + self.socket.setsockopt_string(zmq.SUBSCRIBE, "") + self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) + self.socket.setsockopt(zmq.CONFLATE, True) + self.socket.connect(f"tcp://{self.server_address}:{self.port}") + self._connected = True + + # Auto-detect resolution + if self.width is None or self.height is None: + h, w = self.read().shape[:2] + self.height = h + self.width = w + logger.info(f"{self} resolution: {w}x{h}") + + logger.info(f"{self} connected.") + + if warmup: + time.sleep(0.1) + + except Exception as e: + self._cleanup() + raise RuntimeError(f"Failed to connect to {self}: {e}") from e + + def _cleanup(self): + """Clean up ZMQ resources.""" + self._connected = False + if self.socket: + self.socket.close() + self.socket = None + if self.context: + self.context.term() + self.context = None + + @staticmethod + def find_cameras() -> list[dict[str, Any]]: + """ZMQ cameras require manual configuration (server address/port).""" + return [] + + def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: + """ + Read a single frame from the ZMQ camera. 
+ + Returns: + np.ndarray: Decoded frame (height, width, 3) + """ + if not self.is_connected or self.socket is None: + raise DeviceNotConnectedError(f"{self} is not connected.") + + try: + message = self.socket.recv_string() + except Exception as e: + if type(e).__name__ == "Again": + raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e + raise + + # Decode JSON message + data = json.loads(message) + + if "images" not in data: + raise RuntimeError(f"{self} invalid message: missing 'images' key") + + images = data["images"] + + # Get image by camera name or first available + if self.camera_name in images: + img_b64 = images[self.camera_name] + elif images: + img_b64 = next(iter(images.values())) + else: + raise RuntimeError(f"{self} no images in message") + + # Decode base64 JPEG + img_bytes = base64.b64decode(img_b64) + frame = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR) + + if frame is None: + raise RuntimeError(f"{self} failed to decode image") + + return frame + + def _read_loop(self) -> None: + while self.stop_event and not self.stop_event.is_set(): + try: + frame = self.read() + with self.frame_lock: + self.latest_frame = frame + self.new_frame_event.set() + except DeviceNotConnectedError: + break + except TimeoutError: + pass + except Exception as e: + logger.warning(f"Read error: {e}") + + def _start_read_thread(self) -> None: + if self.thread and self.thread.is_alive(): + return + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, daemon=True) + self.thread.start() + + def _stop_read_thread(self) -> None: + if self.stop_event: + self.stop_event.set() + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=2.0) + self.thread = None + self.stop_event = None + + def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]: + """Return the latest frame from the background read thread, waiting up to `timeout_ms` for a new frame (raises TimeoutError otherwise).""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if not 
self.thread or not self.thread.is_alive(): + self._start_read_thread() + + if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): + raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms") + + with self.frame_lock: + frame = self.latest_frame + self.new_frame_event.clear() + + if frame is None: + raise RuntimeError(f"{self} no frame available") + + return frame + + def disconnect(self) -> None: + """Disconnect from ZMQ camera.""" + if not self.is_connected and not self.thread: + raise DeviceNotConnectedError(f"{self} not connected.") + + self._stop_read_thread() + self._cleanup() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py new file mode 100644 index 00000000000..027ae12b5b5 --- /dev/null +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +from ..configs import CameraConfig, ColorMode + +__all__ = ["ZMQCameraConfig", "ColorMode"] + + +@CameraConfig.register_subclass("zmq") +@dataclass +class ZMQCameraConfig(CameraConfig): + server_address: str + port: int = 5555 + camera_name: str = "zmq_camera" + color_mode: ColorMode = ColorMode.RGB + timeout_ms: int = 5000 + + def __post_init__(self) -> None: + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." + ) + + if self.timeout_ms <= 0: + raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.") + + if not self.server_address: + raise ValueError("`server_address` cannot be empty.") + + if self.port <= 0 or self.port > 65535: + raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.") diff --git a/src/lerobot/cameras/zmq/image_server.py b/src/lerobot/cameras/zmq/image_server.py new file mode 100644 index 00000000000..2da366cefcb --- /dev/null +++ b/src/lerobot/cameras/zmq/image_server.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Streams camera images over ZMQ. +Uses lerobot's OpenCVCamera for capture, encodes images to base64 and sends them over ZMQ. 
+""" + +import base64 +import contextlib +import json +import logging +import time +from collections import deque + +import cv2 +import numpy as np +import zmq + +from lerobot.cameras.configs import ColorMode +from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig + +logger = logging.getLogger(__name__) + + +def encode_image(image: np.ndarray, quality: int = 80) -> str: + """Encode RGB image to base64 JPEG string.""" + _, buffer = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) + return base64.b64encode(buffer).decode("utf-8") + + +class ImageServer: + def __init__(self, config: dict, port: int = 5555): + self.fps = config.get("fps", 30) + self.cameras: dict[str, OpenCVCamera] = {} + + for name, cfg in config.get("cameras", {}).items(): + shape = cfg.get("shape", [480, 640]) + cam_config = OpenCVCameraConfig( + index_or_path=cfg.get("device_id", 0), + fps=self.fps, + width=shape[1], + height=shape[0], + color_mode=ColorMode.RGB, + ) + camera = OpenCVCamera(cam_config) + camera.connect() + self.cameras[name] = camera + logger.info(f"Camera {name}: {shape[1]}x{shape[0]}") + + # ZMQ PUB socket + self.context = zmq.Context() + self.socket = self.context.socket(zmq.PUB) + self.socket.setsockopt(zmq.SNDHWM, 20) + self.socket.setsockopt(zmq.LINGER, 0) + self.socket.bind(f"tcp://*:{port}") + + logger.info(f"ImageServer running on port {port}") + + def run(self): + frame_count = 0 + frame_times = deque(maxlen=60) + + try: + while True: + t0 = time.time() + + # Build message + message = {"timestamps": {}, "images": {}} + for name, cam in self.cameras.items(): + frame = cam.read() # Returns RGB + message["timestamps"][name] = time.time() + message["images"][name] = encode_image(frame) + + # Send as JSON string (suppress if buffer full) + with contextlib.suppress(zmq.Again): + self.socket.send_string(json.dumps(message), zmq.NOBLOCK) + + frame_count += 1 + frame_times.append(time.time() - t0) + + if frame_count % 60 == 0: + 
logger.debug(f"FPS: {len(frame_times) / sum(frame_times):.1f}") + + sleep = (1.0 / self.fps) - (time.time() - t0) + if sleep > 0: + time.sleep(sleep) + + except KeyboardInterrupt: + pass + finally: + for cam in self.cameras.values(): + cam.disconnect() + self.socket.close() + self.context.term() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + config = {"fps": 30, "cameras": {"head_camera": {"device_id": 4, "shape": [480, 640]}}} + ImageServer(config, port=5555).run() diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index c66edbd1cb1..0b163019dc9 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -16,6 +16,8 @@ from dataclasses import dataclass, field +from lerobot.cameras import CameraConfig + from ..config import RobotConfig _GAINS: dict[str, dict[str, list[float]]] = { @@ -60,3 +62,6 @@ class UnitreeG1Config(RobotConfig): # Socket config for ZMQ bridge robot_ip: str = "192.168.123.164" # default G1 IP + + # Cameras (ZMQ-based remote cameras) + cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 5bc4f31101b..1764f31b5c8 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -23,13 +23,8 @@ from typing import Any import numpy as np -from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_ -from unitree_sdk2py.idl.unitree_hg.msg.dds_ import ( - LowCmd_ as hg_LowCmd, - LowState_ as hg_LowState, -) -from unitree_sdk2py.utils.crc import CRC +from lerobot.cameras.utils import make_cameras_from_configs from lerobot.envs.factory import make_env from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex @@ -43,8 +38,6 @@ kTopicLowCommand_Debug = "rt/lowcmd" kTopicLowState = "rt/lowstate" -G1_29_Num_Motors = 29 - 
@dataclass class MotorState: @@ -66,28 +59,12 @@ class IMUState: # g1 observation class @dataclass class G1_29_LowState: # noqa: N801 - motor_state: list[MotorState] = field( - default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)] - ) + motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex]) imu_state: IMUState = field(default_factory=IMUState) wireless_remote: Any = None # Raw wireless remote data mode_machine: int = 0 # Robot mode -class DataBuffer: - def __init__(self): - self.data = None - self.lock = threading.Lock() - - def get_data(self): - with self.lock: - return self.data - - def set_data(self, data): - with self.lock: - self.data = data - - class UnitreeG1(Robot): config_class = UnitreeG1Config name = "unitree_g1" @@ -117,9 +94,12 @@ def __init__(self, config: UnitreeG1Config): logger.info("Initialize UnitreeG1...") self.config = config - self.control_dt = config.control_dt + # Initialize cameras config (ZMQ-based) - actual connection in connect() + self._cameras = make_cameras_from_configs(config.cameras) + + # Import channel classes based on mode if config.is_simulation: from unitree_sdk2py.core.channel import ( ChannelFactoryInitialize, @@ -133,62 +113,33 @@ def __init__(self, config: UnitreeG1Config): ChannelSubscriber, ) - # connect robot - self.ChannelFactoryInitialize = ChannelFactoryInitialize - self.connect() + # Store for use in connect() + self._ChannelFactoryInitialize = ChannelFactoryInitialize + self._ChannelPublisher = ChannelPublisher + self._ChannelSubscriber = ChannelSubscriber - # initialize direct motor control interface - self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd) - self.lowcmd_publisher.Init() - self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState) - self.lowstate_subscriber.Init() - self.lowstate_buffer = DataBuffer() - - # initialize subscribe thread to read robot state + # Initialize state variables + 
self.sim_env = None + self._env_wrapper = None + self._lowstate = None self._shutdown_event = threading.Event() - self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state) - self.subscribe_thread.start() - - while not self.is_connected: - time.sleep(0.1) - - # initialize hg's lowcmd msg - self.crc = CRC() - self.msg = unitree_hg_msg_dds__LowCmd_() - self.msg.mode_pr = 0 - - # Wait for first state message to arrive - lowstate = None - while lowstate is None: - lowstate = self.lowstate_buffer.get_data() - if lowstate is None: - time.sleep(0.01) - logger.warning("[UnitreeG1] Waiting for robot state...") - logger.warning("[UnitreeG1] Connected to robot.") - self.msg.mode_machine = lowstate.mode_machine - - # initialize all motors with unified kp/kd from config - self.kp = np.array(config.kp, dtype=np.float32) - self.kd = np.array(config.kd, dtype=np.float32) - - for id in G1_29_JointIndex: - self.msg.motor_cmd[id].mode = 1 - self.msg.motor_cmd[id].kp = self.kp[id.value] - self.msg.motor_cmd[id].kd = self.kd[id.value] - self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q - - # Initialize remote controller + self.subscribe_thread = None self.remote_controller = self.RemoteController() def _subscribe_motor_state(self): # polls robot state @ 250Hz while not self._shutdown_event.is_set(): start_time = time.time() + + # Step simulation if in simulation mode + if self.config.is_simulation and self.sim_env is not None: + self.sim_env.step() + msg = self.lowstate_subscriber.Read() if msg is not None: lowstate = G1_29_LowState() - # Capture motor states - for id in range(G1_29_Num_Motors): + # Capture motor states using jointindex + for id in G1_29_JointIndex: lowstate.motor_state[id].q = msg.motor_state[id].q lowstate.motor_state[id].dq = msg.motor_state[id].dq lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est @@ -207,7 +158,7 @@ def _subscribe_motor_state(self): # polls robot state @ 250Hz # Capture mode_machine lowstate.mode_machine = 
msg.mode_machine - self.lowstate_buffer.set_data(lowstate) + self._lowstate = lowstate current_time = time.time() all_t_elapsed = current_time - start_time @@ -216,7 +167,7 @@ def _subscribe_motor_state(self): # polls robot state @ 250Hz @cached_property def action_features(self) -> dict[str, type]: - return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex} + return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex} def calibrate(self) -> None: # robot is already calibrated pass @@ -225,20 +176,153 @@ def configure(self) -> None: pass def connect(self, calibrate: bool = True) -> None: # connect to DDS + from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_ + from unitree_sdk2py.idl.unitree_hg.msg.dds_ import ( + LowCmd_ as hg_LowCmd, + LowState_ as hg_LowState, + ) + from unitree_sdk2py.utils.crc import CRC + + # Initialize DDS channel and simulation environment if self.config.is_simulation: - self.ChannelFactoryInitialize(0, "lo") - self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) + self._ChannelFactoryInitialize(0, "lo") + self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) + # Extract the actual gym env from the dict structure + self.sim_env = self._env_wrapper["hub_env"][0].envs[0] else: - self.ChannelFactoryInitialize(0) + self._ChannelFactoryInitialize(0) + + # Initialize direct motor control interface + self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd) + self.lowcmd_publisher.Init() + self.lowstate_subscriber = self._ChannelSubscriber(kTopicLowState, hg_LowState) + self.lowstate_subscriber.Init() + + # Start subscribe thread to read robot state + self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state) + self.subscribe_thread.start() + + # Connect cameras + for cam in self._cameras.values(): + if not cam.is_connected: + cam.connect() + + logger.info(f"Connected {len(self._cameras)} 
camera(s).") + + # Initialize lowcmd message + self.crc = CRC() + self.msg = unitree_hg_msg_dds__LowCmd_() + self.msg.mode_pr = 0 + + # Wait for first state message to arrive + lowstate = None + while lowstate is None: + lowstate = self._lowstate + if lowstate is None: + time.sleep(0.01) + logger.warning("[UnitreeG1] Waiting for robot state...") + logger.warning("[UnitreeG1] Connected to robot.") + self.msg.mode_machine = lowstate.mode_machine + + # Initialize all motors with unified kp/kd from config + self.kp = np.array(self.config.kp, dtype=np.float32) + self.kd = np.array(self.config.kd, dtype=np.float32) + + for id in G1_29_JointIndex: + self.msg.motor_cmd[id].mode = 1 + self.msg.motor_cmd[id].kp = self.kp[id.value] + self.msg.motor_cmd[id].kd = self.kd[id.value] + self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q def disconnect(self): + # Signal thread to stop and unblock any waits self._shutdown_event.set() - self.subscribe_thread.join(timeout=2.0) - if self.config.is_simulation: - self.mujoco_env["hub_env"][0].envs[0].kill_sim() + + # Wait for subscribe thread to finish + if self.subscribe_thread is not None: + self.subscribe_thread.join(timeout=2.0) + if self.subscribe_thread.is_alive(): + logger.warning("Subscribe thread did not stop cleanly") + + # Close simulation environment + if self.config.is_simulation and self.sim_env is not None: + try: + # Force-kill the image publish subprocess first to avoid long waits + if hasattr(self.sim_env, "simulator") and hasattr(self.sim_env.simulator, "sim_env"): + sim_env_inner = self.sim_env.simulator.sim_env + if hasattr(sim_env_inner, "image_publish_process"): + proc = sim_env_inner.image_publish_process + if proc.process and proc.process.is_alive(): + logger.info("Force-terminating image publish subprocess...") + proc.stop_event.set() + proc.process.terminate() + proc.process.join(timeout=1) + if proc.process.is_alive(): + proc.process.kill() + self.sim_env.close() + except Exception as e: + 
logger.warning(f"Error closing sim_env: {e}") + self.sim_env = None + self._env_wrapper = None + + # Disconnect cameras + for cam in self._cameras.values(): + cam.disconnect() def get_observation(self) -> dict[str, Any]: - return self.lowstate_buffer.get_data() + lowstate = self._lowstate + if lowstate is None: + return {} + + obs = {} + + # Motors - q, dq, tau for all joints + for motor in G1_29_JointIndex: + name = motor.name + idx = motor.value + obs[f"{name}.q"] = lowstate.motor_state[idx].q + obs[f"{name}.dq"] = lowstate.motor_state[idx].dq + obs[f"{name}.tau"] = lowstate.motor_state[idx].tau_est + + # IMU - gyroscope + if lowstate.imu_state.gyroscope: + obs["imu.gyro.x"] = lowstate.imu_state.gyroscope[0] + obs["imu.gyro.y"] = lowstate.imu_state.gyroscope[1] + obs["imu.gyro.z"] = lowstate.imu_state.gyroscope[2] + + # IMU - accelerometer + if lowstate.imu_state.accelerometer: + obs["imu.accel.x"] = lowstate.imu_state.accelerometer[0] + obs["imu.accel.y"] = lowstate.imu_state.accelerometer[1] + obs["imu.accel.z"] = lowstate.imu_state.accelerometer[2] + + # IMU - quaternion + if lowstate.imu_state.quaternion: + obs["imu.quat.w"] = lowstate.imu_state.quaternion[0] + obs["imu.quat.x"] = lowstate.imu_state.quaternion[1] + obs["imu.quat.y"] = lowstate.imu_state.quaternion[2] + obs["imu.quat.z"] = lowstate.imu_state.quaternion[3] + + # IMU - rpy + if lowstate.imu_state.rpy: + obs["imu.rpy.roll"] = lowstate.imu_state.rpy[0] + obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1] + obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2] + + # Controller - parse wireless_remote and add to obs + if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24: + self.remote_controller.set(lowstate.wireless_remote) + obs["remote.buttons"] = self.remote_controller.button.copy() + obs["remote.lx"] = self.remote_controller.lx + obs["remote.ly"] = self.remote_controller.ly + obs["remote.rx"] = self.remote_controller.rx + obs["remote.ry"] = self.remote_controller.ry + + # Cameras - read 
images from ZMQ cameras + for cam_name, cam in self._cameras.items(): + obs[cam_name] = cam.async_read() + + return obs @property def is_calibrated(self) -> bool: @@ -246,11 +330,15 @@ def is_calibrated(self) -> bool: @property def is_connected(self) -> bool: - return self.lowstate_buffer.get_data() is not None + return self._lowstate is not None @property def _motors_ft(self) -> dict[str, type]: - return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex} + return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex} + + @property + def cameras(self) -> dict: + return self._cameras @property def _cameras_ft(self) -> dict[str, tuple]: @@ -293,39 +381,51 @@ def reset( self, control_dt: float | None = None, default_positions: list[float] | None = None, - ) -> None: # interpolate to default position + ) -> None: # move robot to default position if control_dt is None: control_dt = self.config.control_dt if default_positions is None: default_positions = np.array(self.config.default_positions, dtype=np.float32) - total_time = 3.0 - num_steps = int(total_time / control_dt) - - # get current state - robot_state = self.get_observation() + if self.config.is_simulation and self.sim_env is not None: + self.sim_env.reset() - # record current positions - init_dof_pos = np.zeros(29, dtype=np.float32) - for i in range(29): - init_dof_pos[i] = robot_state.motor_state[i].q + for motor in G1_29_JointIndex: + self.msg.motor_cmd[motor.value].q = default_positions[motor.value] + self.msg.motor_cmd[motor.value].qd = 0 + self.msg.motor_cmd[motor.value].kp = self.kp[motor.value] + self.msg.motor_cmd[motor.value].kd = self.kd[motor.value] + self.msg.motor_cmd[motor.value].tau = 0 + self.msg.crc = self.crc.Crc(self.msg) + self.lowcmd_publisher.Write(self.msg) + else: + total_time = 3.0 + num_steps = int(total_time / control_dt) - # Interpolate to default position - for step in range(num_steps): - start_time = time.time() + # get current state + 
obs = self.get_observation() - alpha = step / num_steps - action_dict = {} + # record current positions + init_dof_pos = np.zeros(29, dtype=np.float32) for motor in G1_29_JointIndex: - target_pos = default_positions[motor.value] - interp_pos = init_dof_pos[motor.value] * (1 - alpha) + target_pos * alpha - action_dict[f"{motor.name}.q"] = float(interp_pos) + init_dof_pos[motor.value] = obs[f"{motor.name}.q"] - self.send_action(action_dict) + # Interpolate to default position + for step in range(num_steps): + start_time = time.time() - # Maintain constant control rate - elapsed = time.time() - start_time - sleep_time = max(0, control_dt - elapsed) - time.sleep(sleep_time) + alpha = step / num_steps + action_dict = {} + for motor in G1_29_JointIndex: + target_pos = default_positions[motor.value] + interp_pos = init_dof_pos[motor.value] * (1 - alpha) + target_pos * alpha + action_dict[f"{motor.name}.q"] = float(interp_pos) + + self.send_action(action_dict) + + # Maintain constant control rate + elapsed = time.time() - start_time + sleep_time = max(0, control_dt - elapsed) + time.sleep(sleep_time) logger.info("Reached default position") diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 8eafa8e6da1..9a327d98653 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -74,6 +74,7 @@ ) from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer @@ -103,6 +104,7 @@ make_robot_from_config, omx_follower, so_follower, + unitree_g1, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -508,6 +510,11 @@ def record(cfg: 
RecordConfig) -> LeRobotDataset: (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] ): log_say("Reset the environment", cfg.play_sounds) + + # reset g1 robot + if robot.name == "unitree_g1": + robot.reset() + record_loop( robot=robot, events=events, diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 8e0d9cf6d89..c16271932c0 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -60,6 +60,7 @@ make_robot_from_config, omx_follower, so_follower, + unitree_g1, ) from lerobot.utils.constants import ACTION from lerobot.utils.import_utils import register_third_party_plugins