diff --git a/.gitattributes b/.gitattributes index 44e16cf1d0..5bd931b188 100644 --- a/.gitattributes +++ b/.gitattributes @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - *.memmap filter=lfs diff=lfs merge=lfs -text *.stl filter=lfs diff=lfs merge=lfs -text *.safetensors filter=lfs diff=lfs merge=lfs -text *.mp4 filter=lfs diff=lfs merge=lfs -text *.arrow filter=lfs diff=lfs merge=lfs -text *.json !text !filter !merge !diff +tests/artifacts/cameras/*.{png,bag} filter=lfs diff=lfs merge=lfs -text diff --git a/lerobot/common/cameras/camera.py b/lerobot/common/cameras/camera.py index 5d9ac50942..d25bedb11b 100644 --- a/lerobot/common/cameras/camera.py +++ b/lerobot/common/cameras/camera.py @@ -1,25 +1,50 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import abc import numpy as np +from .configs import CameraConfig, ColorMode + +# NOTE(Steven): Consider something like configure() if makes sense for both cameras class Camera(abc.ABC): + def __init__(self, config: CameraConfig): + self.fps: int | None = config.fps + self.width: int | None = config.width + self.height: int | None = config.height + + @property @abc.abstractmethod - def connect(self): + def is_connected(self) -> bool: pass @abc.abstractmethod - def read(self, temporary_color: str | None = None) -> np.ndarray: + def connect(self, do_warmup_read: bool = True) -> None: pass @abc.abstractmethod - def async_read(self) -> np.ndarray: + def read(self, color_mode: ColorMode | None = None) -> np.ndarray: pass @abc.abstractmethod - def disconnect(self): + def async_read(self, timeout_ms: float = 2000) -> np.ndarray: pass - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() + @abc.abstractmethod + def disconnect(self) -> None: + pass diff --git a/lerobot/common/cameras/configs.py b/lerobot/common/cameras/configs.py index 4c796a03de..1f8dda865b 100644 --- a/lerobot/common/cameras/configs.py +++ b/lerobot/common/cameras/configs.py @@ -1,11 +1,44 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
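The new `Camera` ABC above pins down the driver-facing surface: an `is_connected` property plus `connect()`, `read()`, `async_read()` and `disconnect()`, all built on a `CameraConfig`. Below is a minimal sketch of a concrete subclass that satisfies this interface with synthetic frames; `DummyCamera` and `DummyCameraConfig` are illustrative names, not part of this PR.

```python
from dataclasses import dataclass

import numpy as np

from lerobot.common.cameras.camera import Camera
from lerobot.common.cameras.configs import CameraConfig, ColorMode


@dataclass(kw_only=True)
class DummyCameraConfig(CameraConfig):
    # Illustrative only. A real config would also be registered with
    # @CameraConfig.register_subclass("dummy") so draccus can select it by name.
    fps: int | None = 30
    width: int | None = 640
    height: int | None = 480


class DummyCamera(Camera):
    """Serves black frames; handy for exercising code that expects the Camera API."""

    def __init__(self, config: DummyCameraConfig):
        super().__init__(config)  # populates self.fps, self.width, self.height
        self._connected = False

    @property
    def is_connected(self) -> bool:
        return self._connected

    def connect(self, do_warmup_read: bool = True) -> None:
        self._connected = True
        if do_warmup_read:
            self.read()

    def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
        # Frames follow the (height, width, channels) uint8 convention used elsewhere.
        return np.zeros((self.height, self.width, 3), dtype=np.uint8)

    def async_read(self, timeout_ms: float = 2000) -> np.ndarray:
        # A real driver would serve the latest frame from a background thread;
        # here read() is cheap enough to call directly.
        return self.read()

    def disconnect(self) -> None:
        self._connected = False


cam = DummyCamera(DummyCameraConfig())
cam.connect()
print(cam.read().shape)  # (480, 640, 3)
cam.disconnect()
```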
+ import abc from dataclasses import dataclass +from enum import Enum import draccus -@dataclass +class ColorMode(Enum): + RGB = "rgb" + BGR = "bgr" + + +class Cv2Rotation(Enum): + NO_ROTATION = 0 + ROTATE_90 = 90 + ROTATE_180 = 180 + ROTATE_270 = -90 + + +@dataclass(kw_only=True) class CameraConfig(draccus.ChoiceRegistry, abc.ABC): + fps: int | None = None + width: int | None = None + height: int | None = None + @property def type(self) -> str: return self.get_choice_name(self.__class__) diff --git a/lerobot/common/cameras/intel/__init__.py b/lerobot/common/cameras/intel/__init__.py index d875ebf4e1..5786667c7c 100644 --- a/lerobot/common/cameras/intel/__init__.py +++ b/lerobot/common/cameras/intel/__init__.py @@ -1,4 +1,2 @@ from .camera_realsense import RealSenseCamera from .configuration_realsense import RealSenseCameraConfig - -__all__ = ["RealSenseCamera", "RealSenseCameraConfig"] diff --git a/lerobot/common/cameras/intel/camera_realsense.py b/lerobot/common/cameras/intel/camera_realsense.py index cd41286f25..274eda8420 100644 --- a/lerobot/common/cameras/intel/camera_realsense.py +++ b/lerobot/common/cameras/intel/camera_realsense.py @@ -13,523 +13,663 @@ # limitations under the License. """ -This file contains utilities for recording frames from Intel Realsense cameras. +Provides the RealSenseCamera class for capturing frames from Intel RealSense cameras. """ -import argparse -import concurrent.futures +import contextlib import logging import math -import shutil -import threading +import queue import time -import traceback -from collections import Counter -from pathlib import Path -from threading import Thread +from threading import Event, Thread +from typing import Any, Dict, List, Tuple, Union +import cv2 import numpy as np -from PIL import Image +import pyrealsense2 as rs from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError -from lerobot.common.utils.robot_utils import ( - busy_wait, -) from lerobot.common.utils.utils import capture_timestamp_utc from ..camera import Camera +from ..configs import ColorMode +from ..utils import get_cv2_rotation from .configuration_realsense import RealSenseCameraConfig -SERIAL_NUMBER_INDEX = 1 +logger = logging.getLogger(__name__) -def find_cameras(raise_when_empty=True, mock=False) -> list[dict]: - """ - Find the names and the serial numbers of the Intel RealSense cameras - connected to the computer. - """ - if mock: - import tests.cameras.mock_pyrealsense2 as rs - else: - import pyrealsense2 as rs - - cameras = [] - for device in rs.context().query_devices(): - serial_number = int(device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX))) - name = device.get_info(rs.camera_info.name) - cameras.append( - { - "serial_number": serial_number, - "name": name, - } - ) - - if raise_when_empty and len(cameras) == 0: - raise OSError( - "Not a single camera was detected. Try re-plugging, or re-installing `librealsense` and its python wrapper `pyrealsense2`, or updating the firmware." 
- ) - - return cameras - - -def save_image(img_array, serial_number, frame_index, images_dir): - try: - img = Image.fromarray(img_array) - path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png" - path.parent.mkdir(parents=True, exist_ok=True) - img.save(str(path), quality=100) - logging.info(f"Saved image: {path}") - except Exception as e: - logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") - - -def save_images_from_cameras( - images_dir: Path, - serial_numbers: list[int] | None = None, - fps=None, - width=None, - height=None, - record_time_s=2, - mock=False, -): - """ - Initializes all the cameras and saves images to the directory. Useful to visually identify the camera - associated to a given serial number. +class RealSenseCamera(Camera): """ - if serial_numbers is None or len(serial_numbers) == 0: - camera_infos = find_cameras(mock=mock) - serial_numbers = [cam["serial_number"] for cam in camera_infos] - - if mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - print("Connecting cameras") - cameras = [] - for cam_sn in serial_numbers: - print(f"{cam_sn=}") - config = RealSenseCameraConfig(serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock) + Manages interactions with Intel RealSense cameras for frame and depth recording. + + This class provides an interface similar to `OpenCVCamera` but tailored for + RealSense devices, leveraging the `pyrealsense2` library. It uses the camera's + unique serial number for identification, offering more stability than device + indices, especially on Linux. It also supports capturing depth maps alongside + color frames. + + A `RealSenseCamera` instance requires a configuration object specifying the + camera's serial number or a unique device name. If using the name, ensure only + one camera with that name is connected. + + The camera's default settings (FPS, resolution, color mode) from the stream + profile are used unless overridden in the configuration. + + Args: + config (RealSenseCameraConfig): Configuration object containing settings like + serial number or name, desired FPS, width, height, color mode, rotation, + and whether to capture depth. 
+ + Example: + ```python + from lerobot.common.cameras.intel.camera_realsense import RealSenseCamera + from lerobot.common.cameras.intel.configuration_realsense import RealSenseCameraConfig + from lerobot.common.cameras.configs import ColorMode + + # Basic usage with serial number + config = RealSenseCameraConfig(serial_number="1234567890") # Replace with actual SN camera = RealSenseCamera(config) - camera.connect() - print( - f"RealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})" + try: + camera.connect() + print(f"Connected to {camera}") + color_image = camera.read() # Synchronous read (color only) + print(f"Read frame shape: {color_image.shape}") + async_image = camera.async_read() # Asynchronous read + print(f"Async read frame shape: {async_image.shape}") + except Exception as e: + print(f"An error occurred: {e}") + finally: + camera.disconnect() + print(f"Disconnected from {camera}") + + # Example with depth capture and custom settings + custom_config = RealSenseCameraConfig( + serial_number="1234567890", # Replace with actual SN + fps=30, + width=1280, + height=720, + color_mode=ColorMode.BGR, # Request BGR output + rotation=0, + use_depth=True ) - cameras.append(camera) + depth_camera = RealSenseCamera(custom_config) + try: + depth_camera.connect() + color_image, depth_map = depth_camera.read() # Returns tuple + print(f"Color shape: {color_image.shape}, Depth shape: {depth_map.shape}") + finally: + depth_camera.disconnect() + + # Example using a unique camera name + name_config = RealSenseCameraConfig(name="Intel RealSense D435") # If unique + name_camera = RealSenseCamera(name_config) + # ... connect, read, disconnect ... + ``` + """ - images_dir = Path(images_dir) - if images_dir.exists(): - shutil.rmtree( - images_dir, - ) - images_dir.mkdir(parents=True, exist_ok=True) - - print(f"Saving images to {images_dir}") - frame_index = 0 - start_time = time.perf_counter() - try: - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - while True: - now = time.perf_counter() - - for camera in cameras: - # If we use async_read when fps is None, the loop will go full speed, and we will end up - # saving the same images from the cameras multiple times until the RAM/disk is full. - image = camera.read() if fps is None else camera.async_read() - if image is None: - print("No Frame") - - bgr_converted_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - - executor.submit( - save_image, - bgr_converted_image, - camera.serial_number, - frame_index, - images_dir, - ) - - if fps is not None: - dt_s = time.perf_counter() - now - busy_wait(1 / fps - dt_s) - - if time.perf_counter() - start_time > record_time_s: - break - - print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") - - frame_index += 1 - finally: - print(f"Images have been saved to {images_dir}") - for camera in cameras: - camera.disconnect() + def __init__(self, config: RealSenseCameraConfig): + """ + Initializes the RealSenseCamera instance. + Args: + config: The configuration settings for the camera. 
+ """ -class RealSenseCamera(Camera): - """ - The RealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras: - - is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux, - - can also be instantiated with the camera's name — if it's unique — using RealSenseCamera.init_from_name(), - - depth map can be returned. - - To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera: - ```bash - python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras - ``` - - When an RealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode - of the given camera will be used. - - Example of instantiating with a serial number: - ```python - from lerobot.common.robot_devices.cameras.configs import RealSenseCameraConfig - - config = RealSenseCameraConfig(serial_number=128422271347) - camera = RealSenseCamera(config) - camera.connect() - color_image = camera.read() - # when done using the camera, consider disconnecting - camera.disconnect() - ``` - - Example of instantiating with a name if it's unique: - ``` - config = RealSenseCameraConfig(name="Intel RealSense D405") - ``` - - Example of changing default fps, width, height and color_mode: - ```python - config = RealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720) - config = RealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480) - config = RealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr") - # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera - ``` - - Example of returning depth: - ```python - config = RealSenseCameraConfig(serial_number=128422271347, use_depth=True) - camera = RealSenseCamera(config) - camera.connect() - color_image, depth_map = camera.read() - ``` - """ + super().__init__(config) - def __init__( - self, - config: RealSenseCameraConfig, - ): self.config = config - if config.name is not None: - self.serial_number = self.find_serial_number_from_name(config.name) + + if config.name is not None: # TODO(Steven): Do we want to continue supporting this? + self.serial_number = self._find_serial_number_from_name(config.name) + elif config.serial_number is not None: + self.serial_number = str(config.serial_number) else: - self.serial_number = config.serial_number + raise ValueError("RealSenseCameraConfig must provide either 'serial_number' or 'name'.") - # Store the raw (capture) resolution from the config. - self.capture_width = config.width - self.capture_height = config.height + self.fps: int | None = config.fps + self.channels: int = config.channels + self.color_mode: ColorMode = config.color_mode + self.use_depth: bool = config.use_depth - # If rotated by ±90, swap width and height. 
- if config.rotation in [-90, 90]: - self.width = config.height - self.height = config.width - else: - self.width = config.width - self.height = config.height - - self.fps = config.fps - self.channels = config.channels - self.color_mode = config.color_mode - self.use_depth = config.use_depth - self.force_hardware_reset = config.force_hardware_reset - self.mock = config.mock - - self.camera = None - self.is_connected = False - self.thread = None - self.stop_event = None - self.color_image = None - self.depth_map = None - self.logs = {} + self.rs_pipeline: rs.pipeline | None = None + self.rs_profile: rs.pipeline_profile | None = None - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - self.rotation = None - if config.rotation == -90: - self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE - elif config.rotation == 90: - self.rotation = cv2.ROTATE_90_CLOCKWISE - elif config.rotation == 180: - self.rotation = cv2.ROTATE_180 - - def find_serial_number_from_name(self, name): - camera_infos = find_cameras() - camera_names = [cam["name"] for cam in camera_infos] - this_name_count = Counter(camera_names)[name] - if this_name_count > 1: - # TODO(aliberts): Test this with multiple identical cameras (Aloha) + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_queue: queue.Queue = queue.Queue(maxsize=1) + + self.logs: dict = {} # For timestamping or other metadata + + self.rotation: int | None = get_cv2_rotation(config.rotation) + + # NOTE(Steven): What happens if rotation is specified but we leave width and height to None? + # NOTE(Steven): Should we enforce these parameters if rotation is set? + if self.height and self.width: + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.prerotated_width, self.prerotated_height = self.height, self.width + else: + self.prerotated_width, self.prerotated_height = self.width, self.height + + def __str__(self) -> str: + """Returns a string representation of the camera instance.""" + return f"{self.__class__.__name__}({self.serial_number})" + + @property + def is_connected(self) -> bool: + """Checks if the camera pipeline is started and streams are active.""" + return self.rs_pipeline is not None and self.rs_profile is not None + + @staticmethod + def find_cameras(raise_when_empty: bool = True) -> List[Dict[str, Any]]: + """ + Detects available Intel RealSense cameras connected to the system. + + Args: + raise_when_empty (bool): If True, raises an OSError if no cameras are found. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, + where each dictionary contains 'type', 'id' (serial number), 'name', + firmware version, USB type, and other available specs, and the default profile properties (width, height, fps, format). + + Raises: + OSError: If `raise_when_empty` is True and no cameras are detected, + or if pyrealsense2 is not installed. + ImportError: If pyrealsense2 is not installed. + """ + found_cameras_info = [] + context = rs.context() + devices = context.query_devices() + + if not devices: + logger.warning("No RealSense devices detected.") + if raise_when_empty: + raise OSError( + "No RealSense devices detected. Ensure cameras are connected, " + "library (`pyrealsense2`) is installed, and firmware is up-to-date." 
+ ) + + for device in devices: + camera_info = { + "name": device.get_info(rs.camera_info.name), + "type": "RealSense", + "id": device.get_info(rs.camera_info.serial_number), + "firmware_version": device.get_info(rs.camera_info.firmware_version), + "usb_type_descriptor": device.get_info(rs.camera_info.usb_type_descriptor), + "physical_port": device.get_info(rs.camera_info.physical_port), + "product_id": device.get_info(rs.camera_info.product_id), + "product_line": device.get_info(rs.camera_info.product_line), + } + + # Get stream profiles for each sensor + sensors = device.query_sensors() + for sensor in sensors: + profiles = sensor.get_stream_profiles() + + for profile in profiles: + if profile.is_video_stream_profile() and profile.is_default(): + vprofile = profile.as_video_stream_profile() + stream_info = { + "stream_type": vprofile.stream_name(), + "format": vprofile.format().name, + "width": vprofile.width(), + "height": vprofile.height(), + "fps": vprofile.fps(), + } + camera_info["default_stream_profile"] = stream_info + + found_cameras_info.append(camera_info) + logger.debug(f"Found RealSense camera: {camera_info}") + + logger.info(f"Detected RealSense cameras: {[cam['id'] for cam in found_cameras_info]}") + return found_cameras_info + + def _find_serial_number_from_name(self, name: str) -> str: + """Finds the serial number for a given unique camera name.""" + camera_infos = self.find_cameras(raise_when_empty=True) + found_devices = [cam for cam in camera_infos if str(cam["name"]) == name] + + if not found_devices: + available_names = [cam["name"] for cam in camera_infos] raise ValueError( - f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them." + f"No RealSense camera found with name '{name}'. Available camera names: {available_names}" ) - name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos} - cam_sn = name_to_serial_dict[name] + if len(found_devices) > 1: + serial_numbers = [dev["serial_number"] for dev in found_devices] + raise ValueError( + f"Multiple RealSense cameras found with name '{name}'. " + f"Please use a unique serial number instead. 
Found SNs: {serial_numbers}" + ) - return cam_sn + serial_number = str(found_devices[0]["serial_number"]) + logger.info(f"Found serial number '{serial_number}' for camera name '{name}'.") + return serial_number - def connect(self): - if self.is_connected: - raise DeviceAlreadyConnectedError(f"RealSenseCamera({self.serial_number}) is already connected.") + def _configure_realsense_settings(self) -> rs.config: + """Creates and configures the RealSense pipeline configuration object.""" + rs_config = rs.config() + rs.config.enable_device(rs_config, self.serial_number) - if self.mock: - import tests.cameras.mock_pyrealsense2 as rs + if self.width and self.height and self.fps: + logger.debug( + f"Requesting Color Stream: {self.prerotated_width}x{self.prerotated_height} @ {self.fps} FPS, Format: {rs.format.rgb8}" + ) + rs_config.enable_stream( + rs.stream.color, self.prerotated_width, self.prerotated_height, rs.format.rgb8, self.fps + ) + if self.use_depth: + logger.debug( + f"Requesting Depth Stream: {self.prerotated_width}x{self.prerotated_height} @ {self.fps} FPS, Format: {rs.format.z16}" + ) + rs_config.enable_stream( + rs.stream.depth, self.prerotated_width, self.prerotated_height, rs.format.z16, self.fps + ) else: - import pyrealsense2 as rs + logger.debug(f"Requesting Color Stream: Default settings, Format: {rs.stream.color}") + rs_config.enable_stream(rs.stream.color) + if self.use_depth: + logger.debug(f"Requesting Depth Stream: Default settings, Format: {rs.stream.depth}") + rs_config.enable_stream(rs.stream.depth) - config = rs.config() - config.enable_device(str(self.serial_number)) + return rs_config - if self.fps and self.capture_width and self.capture_height: - # TODO(rcadene): can we set rgb8 directly? - config.enable_stream( - rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps - ) - else: - config.enable_stream(rs.stream.color) + def _validate_capture_settings(self) -> None: + """ + Validates if the actual stream settings match the requested configuration. + + This method compares the requested FPS, width, and height against the + actual settings obtained from the active RealSense profile after the + pipeline has started. + + Raises: + RuntimeError: If the actual camera settings significantly deviate + from the requested ones. + DeviceNotConnectedError: If the camera is not connected when attempting + to validate settings. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.") + + self._validate_fps(self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()) + self._validate_width_and_height(self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()) if self.use_depth: - if self.fps and self.capture_width and self.capture_height: - config.enable_stream( - rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps - ) - else: - config.enable_stream(rs.stream.depth) + self._validate_fps(self.rs_profile.get_stream(rs.stream.depth).as_video_stream_profile()) + self._validate_width_and_height( + self.rs_profile.get_stream(rs.stream.depth).as_video_stream_profile() + ) + + # NOTE(Steven): Add a wamr-up period time config + def connect(self, do_warmup_read: bool = True): + """ + Connects to the RealSense camera specified in the configuration. + + Initializes the RealSense pipeline, configures the required streams (color + and optionally depth), starts the pipeline, and validates the actual stream settings. 
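A sketch of the discovery-plus-connect flow described here, assuming at least one RealSense device is attached; the dictionary keys come from `find_cameras()` above.

```python
from lerobot.common.cameras.intel.camera_realsense import RealSenseCamera
from lerobot.common.cameras.intel.configuration_realsense import RealSenseCameraConfig

# Enumerate attached devices; each entry carries the serial number under "id".
cameras_info = RealSenseCamera.find_cameras(raise_when_empty=True)
serial_number = cameras_info[0]["id"]
print(f"Using RealSense {cameras_info[0]['name']} (SN {serial_number})")

# Leaving fps/width/height as None keeps the camera's default stream profile.
config = RealSenseCameraConfig(serial_number=serial_number)
camera = RealSenseCamera(config)

try:
    # connect() starts the pipeline, validates the negotiated settings and,
    # with do_warmup_read=True, grabs one frame before returning.
    camera.connect(do_warmup_read=True)
    frame = camera.read()
    print(frame.shape)  # (height, width, 3), RGB by default
finally:
    if camera.is_connected:
        camera.disconnect()
```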
+ + Raises: + DeviceAlreadyConnectedError: If the camera is already connected. + ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique). + ConnectionError: If the camera is found but fails to start the pipeline. + RuntimeError: If the pipeline starts but fails to apply requested settings. + OSError: If no RealSense devices are detected at all. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} is already connected.") + + logger.debug(f"Attempting to connect to camera {self.serial_number}...") + self.rs_pipeline = rs.pipeline() + rs_config = self._configure_realsense_settings() - self.camera = rs.pipeline() try: - profile = self.camera.start(config) - is_camera_open = True - except RuntimeError: - is_camera_open = False - traceback.print_exc() - - # If the camera doesn't work, display the camera indices corresponding to - # valid cameras. - if not is_camera_open: - # Verify that the provided `serial_number` is valid before printing the traceback - camera_infos = find_cameras() - serial_numbers = [cam["serial_number"] for cam in camera_infos] - if self.serial_number not in serial_numbers: - raise ValueError( - f"`serial_number` is expected to be one of these available cameras {serial_numbers}, but {self.serial_number} is provided instead. " - "To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`." - ) + self.rs_profile = self.rs_pipeline.start(rs_config) + logger.debug(f"Successfully started pipeline for camera {self.serial_number}.") + except RuntimeError as e: + self.rs_profile = None + self.rs_pipeline = None + raise ConnectionError( + f"Failed to open RealSense camera {self.serial_number}. Error: {e}. " + f"Run 'python -m find_cameras list-cameras' for details." + ) from e + + logger.debug(f"Validating stream configuration for {self.serial_number}...") + self._validate_capture_settings() + + if do_warmup_read: + logger.debug(f"Reading a warm-up frame for {self.serial_number}...") + self.read() # NOTE(Steven): For now we just read one frame, we could also loop for X secs + + logger.info(f"Camera {self.serial_number} connected and configured successfully.") + + def _validate_fps(self, stream) -> None: + """Validates and sets the internal FPS based on actual stream FPS.""" + + actual_fps = stream.fps() + + if self.fps is None: + self.fps = actual_fps + logger.info(f"FPS not specified, using camera default: {self.fps} FPS.") + return + + # Use math.isclose for robust float comparison + if not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + logger.warning( + f"Requested FPS {self.fps} for {self}, but camera reported {actual_fps}. " + "This might be due to camera limitations." + ) + raise RuntimeError( + f"Failed to set requested FPS {self.fps} for {self}. Actual value reported: {actual_fps}." + ) + logger.debug(f"FPS set to {actual_fps} for {self}.") - raise OSError(f"Can't access RealSenseCamera({self.serial_number}).") + def _validate_width_and_height(self, stream) -> None: + """Validates and sets the internal capture width and height based on actual stream width.""" - color_stream = profile.get_stream(rs.stream.color) - color_profile = color_stream.as_video_stream_profile() - actual_fps = color_profile.fps() - actual_width = color_profile.width() - actual_height = color_profile.height() + actual_width = int(round(stream.width())) + actual_height = int(round(stream.height())) - # Using `math.isclose` since actual fps can be a float (e.g. 
29.9 instead of 30) - if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): - # Using `OSError` since it's a broad that encompasses issues related to device communication - raise OSError( - f"Can't set {self.fps=} for RealSenseCamera({self.serial_number}). Actual value is {actual_fps}." + if self.width is None or self.height is None: + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.width, self.height = actual_height, actual_width + self.prerotated_width, self.prerotated_height = actual_width, actual_height + else: + self.width, self.height = actual_width, actual_height + self.prerotated_width, self.prerotated_height = actual_width, actual_height + logger.info(f"Capture width set to camera default: {self.width}.") + logger.info(f"Capture height set to camera default: {self.height}.") + return + + if self.prerotated_width != actual_width: + logger.warning( + f"Requested capture width {self.prerotated_width} for {self}, but camera reported {actual_width}." + ) + raise RuntimeError( + f"Failed to set requested capture width {self.prerotated_width} for {self}. Actual value: {actual_width}." ) - if self.capture_width is not None and self.capture_width != actual_width: - raise OSError( - f"Can't set {self.capture_width=} for RealSenseCamera({self.serial_number}). Actual value is {actual_width}." + logger.debug(f"Capture width set to {actual_width} for {self}.") + + if self.prerotated_height != actual_height: + logger.warning( + f"Requested capture height {self.prerotated_height} for {self}, but camera reported {actual_height}." ) - if self.capture_height is not None and self.capture_height != actual_height: - raise OSError( - f"Can't set {self.capture_height=} for RealSenseCamera({self.serial_number}). Actual value is {actual_height}." + raise RuntimeError( + f"Failed to set requested capture height {self.prerotated_height} for {self}. Actual value: {actual_height}." ) + logger.debug(f"Capture height set to {actual_height} for {self}.") - self.fps = round(actual_fps) - self.capture_width = round(actual_width) - self.capture_height = round(actual_height) + def read_depth(self, timeout_ms: int = 5000) -> np.ndarray: + """ + Reads a single frame (depth) synchronously from the camera. - self.is_connected = True + This is a blocking call. It waits for a coherent set of frames (depth) + from the camera hardware via the RealSense pipeline. - def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]: - """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3) - of type `np.uint8`, contrarily to the pytorch format which is float channel first. + Args: + timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 5000ms. - When `use_depth=True`, returns a tuple `(color_image, depth_map)` with a depth map in the format - height x width (e.g. 480 x 640) of type np.uint16. + Returns: + np.ndarray: The depth map as a NumPy array (height, width) + of type `np.uint16` (raw depth values in millimeters) and rotation. - Note: Reading a frame is done every `camera.fps` times per second, and it is blocking. - If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`. + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading frames from the pipeline fails or frames are invalid. 
""" + if not self.is_connected: - raise DeviceNotConnectedError( - f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." + raise DeviceNotConnectedError(f"{self} is not connected.") + + if not self.use_depth: + raise RuntimeError( + f"Failed to capture depth frame from {self}. '.read_depth()'. Depth stream is not enabled." ) - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 + start_time = time.perf_counter() + + ret, frame = self.rs_pipeline.try_wait_for_frames( + timeout_ms=timeout_ms + ) # NOTE(Steven): This read has a timeout + + if not ret or frame is None: + raise RuntimeError( + f"Failed to capture frame from {self}. '.read_depth()' returned status={ret} and frame is None." + ) + + depth_frame = frame.get_depth_frame() + depth_map = np.asanyarray(depth_frame.get_data()) + + depth_map_processed = self._postprocess_image(depth_map) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} synchronous read took: {read_duration_ms:.1f}ms") + + self.logs["timestamp_utc"] = capture_timestamp_utc() + return depth_map_processed + + def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 5000) -> np.ndarray: + """ + Reads a single frame (color) synchronously from the camera. + + This is a blocking call. It waits for a coherent set of frames (color) + from the camera hardware via the RealSense pipeline. + + Args: + timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 5000ms. + + Returns: + np.ndarray: The captured color frame as a NumPy array + (height, width, channels), processed according to `color_mode` and rotation. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading frames from the pipeline fails or frames are invalid. + ValueError: If an invalid `color_mode` is requested. + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") start_time = time.perf_counter() - frame = self.camera.wait_for_frames(timeout_ms=5000) + ret, frame = self.rs_pipeline.try_wait_for_frames( + timeout_ms=timeout_ms + ) # NOTE(Steven): This read has a timeout + + if not ret or frame is None: + raise RuntimeError( + f"Failed to capture frame from {self}. '.read()' returned status={ret} and frame is None." + ) color_frame = frame.get_color_frame() + color_image_raw = np.asanyarray(color_frame.get_data()) - if not color_frame: - raise OSError(f"Can't capture color image from RealSenseCamera({self.serial_number}).") + color_image_processed = self._postprocess_image(color_image_raw, color_mode) - color_image = np.asanyarray(color_frame.get_data()) + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} synchronous read took: {read_duration_ms:.1f}ms") - requested_color_mode = self.color_mode if temporary_color is None else temporary_color - if requested_color_mode not in ["rgb", "bgr"]: + self.logs["timestamp_utc"] = capture_timestamp_utc() + return color_image_processed + + def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray: + """ + Applies color conversion, dimension validation, and rotation to a raw color frame. + + Args: + image (np.ndarray): The raw image frame (expected RGB format from RealSense). + color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, + uses the instance's default `self.color_mode`. 
+ + Returns: + np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`. + + Raises: + ValueError: If the requested `color_mode` is invalid. + RuntimeError: If the raw frame dimensions do not match the configured + `width` and `height`. + """ + + if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." + f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." ) - # IntelRealSense uses RGB format as default (red, green, blue). - if requested_color_mode == "bgr": - color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR) + h, w, c = image.shape - h, w, _ = color_image.shape - if h != self.capture_height or w != self.capture_width: - raise OSError( - f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." + if h != self.prerotated_height or w != self.prerotated_width: + raise RuntimeError( + f"Captured frame dimensions ({h}x{w}) do not match configured capture dimensions ({self.prerotated_height}x{self.prerotated_width}) for {self}." + ) + if c != self.channels: + logger.warning( + f"Captured frame channels ({c}) do not match configured channels ({self.channels}) for {self}." ) - if self.rotation is not None: - color_image = cv2.rotate(color_image, self.rotation) + processed_image = image + if self.color_mode == ColorMode.BGR: + processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + logger.debug(f"Converted frame from RGB to BGR for {self}.") - # log the number of seconds it took to read the image - self.logs["delta_timestamp_s"] = time.perf_counter() - start_time + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + processed_image = cv2.rotate(processed_image, self.rotation) + logger.debug(f"Rotated frame by {self.config.rotation} degrees for {self}.") - # log the utc time at which the image was received - self.logs["timestamp_utc"] = capture_timestamp_utc() + return processed_image - if self.use_depth: - depth_frame = frame.get_depth_frame() - if not depth_frame: - raise OSError(f"Can't capture depth image from RealSenseCamera({self.serial_number}).") + def _read_loop(self): + """ + Internal loop run by the background thread for asynchronous reading. - depth_map = np.asanyarray(depth_frame.get_data()) + Continuously reads frames (color and optional depth) using `read()` + and places the latest result (single image or tuple) into the `frame_queue`. + It overwrites any previous frame in the queue. + """ + logger.debug(f"Starting read loop thread for {self}.") + while not self.stop_event.is_set(): + try: + frame_data = self.read(timeout_ms=500) - h, w = depth_map.shape - if h != self.capture_height or w != self.capture_width: - raise OSError( - f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." 
- ) + with contextlib.suppress(queue.Empty): + _ = self.frame_queue.get_nowait() + self.frame_queue.put(frame_data) + logger.debug(f"Frame data placed in queue for {self}.") - if self.rotation is not None: - depth_map = cv2.rotate(depth_map, self.rotation) + except DeviceNotConnectedError: + logger.error(f"Read loop for {self} stopped: Camera disconnected.") + break + except Exception as e: + logger.warning(f"Error reading frame in background thread for {self}: {e}") - return color_image, depth_map - else: - return color_image + logger.debug(f"Stopping read loop thread for {self}.") - def read_loop(self): - while not self.stop_event.is_set(): - if self.use_depth: - self.color_image, self.depth_map = self.read() - else: - self.color_image = self.read() + def _ensure_read_thread_running(self): + """Starts or restarts the background read thread if it's not running.""" + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=0.1) + if self.stop_event is not None: + self.stop_event.set() - def async_read(self): - """Access the latest color image""" + self.stop_event = Event() + self.thread = Thread( + target=self._read_loop, args=(), name=f"RealSenseReadLoop-{self}-{self.serial_number}" + ) + self.thread.daemon = True + self.thread.start() + logger.debug(f"Read thread started for {self}.") + + # NOTE(Steven): Missing implementation for depth + def async_read(self, timeout_ms: float = 2000) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + Reads the latest available frame data (color or color+depth) asynchronously. + + This method retrieves the most recent frame captured by the background + read thread. It does not block waiting for the camera hardware directly, + only waits for a frame to appear in the internal queue up to the specified + timeout. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available in the queue. Defaults to 2000ms (2 seconds). + + Returns: + Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + The latest captured frame data (color image or tuple of color image + and depth map), processed according to configuration. Format depends + on `self.use_depth`. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame data becomes available within the specified timeout. + RuntimeError: If the background thread died unexpectedly or another queue error occurs. + """ if not self.is_connected: - raise DeviceNotConnectedError( - f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + self._ensure_read_thread_running() + + try: + return self.frame_queue.get(timeout=timeout_ms / 1000.0) + except queue.Empty as e: + thread_alive = self.thread is not None and self.thread.is_alive() + logger.error( + f"Timeout waiting for frame from {self} queue after {timeout_ms}ms. " + f"(Read thread alive: {thread_alive})" ) + raise TimeoutError( + f"Timed out waiting for frame from camera {self.serial_number} after {timeout_ms} ms. " + f"Read thread alive: {thread_alive}." + ) from e + except Exception as e: + logger.exception(f"Unexpected error getting frame data from queue for {self}: {e}") + raise RuntimeError( + f"Error getting frame data from queue for camera {self.serial_number}: {e}" + ) from e + + # NOTE(Steven): There are multiple functions that are the same between realsense and opencv. 
We should consider moving them to the parent class + def _shutdown_read_thread(self): + """Signals the background read thread to stop and waits for it to join.""" + if self.stop_event is not None: + logger.debug(f"Signaling stop event for read thread of {self}.") + self.stop_event.set() - if self.thread is None: - self.stop_event = threading.Event() - self.thread = Thread(target=self.read_loop, args=()) - self.thread.daemon = True - self.thread.start() - - num_tries = 0 - while self.color_image is None: - # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here - num_tries += 1 - time.sleep(1 / self.fps) - if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()): - raise Exception( - "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called." - ) + if self.thread is not None and self.thread.is_alive(): + logger.debug(f"Waiting for read thread of {self} to join...") + self.thread.join(timeout=2.0) + if self.thread.is_alive(): + logger.warning(f"Read thread for {self} did not terminate gracefully after 2 seconds.") + else: + logger.debug(f"Read thread for {self} joined successfully.") - if self.use_depth: - return self.color_image, self.depth_map - else: - return self.color_image + self.thread = None + self.stop_event = None def disconnect(self): - if not self.is_connected: + """ + Disconnects from the camera, stops the pipeline, and cleans up resources. + + Stops the background read thread (if running) and stops the RealSense pipeline. + + Raises: + DeviceNotConnectedError: If the camera is already disconnected (pipeline not running). + """ + + if not self.is_connected and self.thread is None: raise DeviceNotConnectedError( - f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." + f"Attempted to disconnect {self}, but it appears already disconnected." ) - if self.thread is not None and self.thread.is_alive(): - # wait for the thread to finish - self.stop_event.set() - self.thread.join() - self.thread = None - self.stop_event = None - - self.camera.stop() - self.camera = None - - self.is_connected = False - - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Save a few frames using `RealSenseCamera` for all cameras connected to the computer, or a selected subset." - ) - parser.add_argument( - "--serial-numbers", - type=int, - nargs="*", - default=None, - help="List of serial numbers used to instantiate the `RealSenseCamera`. If not provided, find and use all available camera indices.", - ) - parser.add_argument( - "--fps", - type=int, - default=30, - help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.", - ) - parser.add_argument( - "--width", - type=int, - default=640, - help="Set the width for all cameras. If not provided, use the default width of each camera.", - ) - parser.add_argument( - "--height", - type=int, - default=480, - help="Set the height for all cameras. 
If not provided, use the default height of each camera.", - ) - parser.add_argument( - "--images-dir", - type=Path, - default="outputs/images_from_intelrealsense_cameras", - help="Set directory to save a few frames for each camera.", - ) - parser.add_argument( - "--record-time-s", - type=float, - default=2.0, - help="Set the number of seconds used to record the frames. By default, 2 seconds.", - ) - args = parser.parse_args() - save_images_from_cameras(**vars(args)) + logger.debug(f"Disconnecting from camera {self.serial_number}...") + + if self.thread is not None: + self._shutdown_read_thread() + + if self.rs_pipeline is not None: + logger.debug(f"Stopping RealSense pipeline object for {self}.") + self.rs_pipeline.stop() + self.rs_pipeline = None + self.rs_profile = None + + logger.info(f"Camera {self.serial_number} disconnected successfully.") diff --git a/lerobot/common/cameras/intel/configuration_realsense.py b/lerobot/common/cameras/intel/configuration_realsense.py index 66bb1b4f11..2cb1f42deb 100644 --- a/lerobot/common/cameras/intel/configuration_realsense.py +++ b/lerobot/common/cameras/intel/configuration_realsense.py @@ -14,7 +14,7 @@ from dataclasses import dataclass -from ..configs import CameraConfig +from ..configs import CameraConfig, ColorMode, Cv2Rotation @CameraConfig.register_subclass("intelrealsense") @@ -35,37 +35,32 @@ class RealSenseCameraConfig(CameraConfig): name: str | None = None serial_number: int | None = None - fps: int | None = None - width: int | None = None - height: int | None = None - color_mode: str = "rgb" - channels: int | None = None + color_mode: ColorMode = ColorMode.RGB + channels: int | None = 3 use_depth: bool = False - force_hardware_reset: bool = True - rotation: int | None = None - mock: bool = False + # NOTE(Steven): Check how draccus would deal with this str -> enum + rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION def __post_init__(self): - # bool is stronger than is None, since it works with empty strings - if bool(self.name) and bool(self.serial_number): + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided." + f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." ) - if self.color_mode not in ["rgb", "bgr"]: + if self.rotation not in ( + Cv2Rotation.NO_ROTATION, + Cv2Rotation.ROTATE_90, + Cv2Rotation.ROTATE_180, + Cv2Rotation.ROTATE_270, + ): raise ValueError( - f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." + f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." ) - self.channels = 3 + if self.channels != 3: + raise NotImplementedError(f"Unsupported number of channels: {self.channels}") - at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None - at_least_one_is_none = self.fps is None or self.width is None or self.height is None - if at_least_one_is_not_none and at_least_one_is_none: + if bool(self.name) and bool(self.serial_number): raise ValueError( - "For `fps`, `width` and `height`, either all of them need to be set, or none of them, " - f"but {self.fps=}, {self.width=}, {self.height=} were provided." + f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided." 
) - - if self.rotation not in [-90, None, 90, 180]: - raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") diff --git a/lerobot/common/cameras/opencv/__init__.py b/lerobot/common/cameras/opencv/__init__.py index edfd6df3d4..29ae5175d2 100644 --- a/lerobot/common/cameras/opencv/__init__.py +++ b/lerobot/common/cameras/opencv/__init__.py @@ -1,4 +1,2 @@ from .camera_opencv import OpenCVCamera from .configuration_opencv import OpenCVCameraConfig - -__all__ = ["OpenCVCamera", "OpenCVCameraConfig"] diff --git a/lerobot/common/cameras/opencv/camera_opencv.py b/lerobot/common/cameras/opencv/camera_opencv.py index d1b08743c6..f746bb346a 100644 --- a/lerobot/common/cameras/opencv/camera_opencv.py +++ b/lerobot/common/cameras/opencv/camera_opencv.py @@ -13,30 +13,28 @@ # limitations under the License. """ -This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring. +Provides the OpenCVCamera class for capturing frames from cameras using OpenCV. """ -import argparse -import concurrent.futures +import contextlib +import logging import math import platform -import shutil -import threading +import queue import time from pathlib import Path -from threading import Thread +from threading import Event, Thread +from typing import Any, Dict, List +import cv2 import numpy as np -from PIL import Image from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError -from lerobot.common.utils.robot_utils import ( - busy_wait, -) from lerobot.common.utils.utils import capture_timestamp_utc from ..camera import Camera -from .configuration_opencv import OpenCVCameraConfig +from ..utils import IndexOrPath, get_cv2_backend, get_cv2_rotation +from .configuration_opencv import ColorMode, OpenCVCameraConfig # The maximum opencv device index depends on your operating system. For instance, # if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case @@ -45,475 +43,514 @@ # treat the same cameras as new devices. Thus we select a higher bound to search indices. MAX_OPENCV_INDEX = 60 +logger = logging.getLogger(__name__) -def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: - cameras = [] - if platform.system() == "Linux": - print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") - possible_ports = [str(port) for port in Path("/dev").glob("video*")] - ports = _find_cameras(possible_ports, mock=mock) - for port in ports: - cameras.append( - { - "port": port, - "index": int(port.removeprefix("/dev/video")), - } - ) - else: - print( - "Mac or Windows detected. Finding available camera indices through " - f"scanning all indices from 0 to {MAX_OPENCV_INDEX}" - ) - possible_indices = range(max_index_search_range) - indices = _find_cameras(possible_indices, mock=mock) - for index in indices: - cameras.append( - { - "port": None, - "index": index, - } - ) - return cameras +class OpenCVCamera(Camera): + """ + Manages camera interactions using OpenCV for efficient frame recording. + This class provides a high-level interface to connect to, configure, and read + frames from cameras compatible with OpenCV's VideoCapture. It supports both + synchronous and asynchronous frame reading. 
-def _find_cameras( - possible_camera_ids: list[int | str], raise_when_empty=False, mock=False -) -> list[int | str]: - if mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 + An OpenCVCamera instance requires a camera index (e.g., 0) or a device path + (e.g., '/dev/video0' on Linux). Camera indices can be unstable across reboots + or port changes, especially on Linux. Use the provided utility script to find + available camera indices or paths: + ```bash + python -m lerobot.find_cameras + ``` + + The camera's default settings (FPS, resolution, color mode) are used unless + overridden in the configuration. - camera_ids = [] - for camera_idx in possible_camera_ids: - camera = cv2.VideoCapture(camera_idx) - is_open = camera.isOpened() - camera.release() + Args: + config (OpenCVCameraConfig): Configuration object containing settings like + camera index/path, desired FPS, width, height, color mode, and rotation. - if is_open: - print(f"Camera found at index {camera_idx}") - camera_ids.append(camera_idx) + Example: + ```python + from lerobot.common.cameras.opencv import OpenCVCamera + from lerobot.common.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode - if raise_when_empty and len(camera_ids) == 0: - raise OSError( - "Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, " - "or your camera driver, or make sure your camera is compatible with opencv2." + # Basic usage with camera index 0 + config = OpenCVCameraConfig(index_or_path=0) + camera = OpenCVCamera(config) + try: + camera.connect() + print(f"Connected to {camera}") + color_image = camera.read() # Synchronous read + print(f"Read frame shape: {color_image.shape}") + async_image = camera.async_read() # Asynchronous read + print(f"Async read frame shape: {async_image.shape}") + except Exception as e: + print(f"An error occurred: {e}") + finally: + camera.disconnect() + print(f"Disconnected from {camera}") + + # Example with custom settings + custom_config = OpenCVCameraConfig( + index_or_path='/dev/video0', # Or use an index + fps=30, + width=1280, + height=720, + color_mode=ColorMode.RGB, + rotation=90 ) + custom_camera = OpenCVCamera(custom_config) + # ... connect, read, disconnect ... + ``` + """ - return camera_ids + def __init__(self, config: OpenCVCameraConfig): + """ + Initializes the OpenCVCamera instance. + Args: + config: The configuration settings for the camera. 
+ """ + super().__init__(config) -def is_valid_unix_path(path: str) -> bool: - """Note: if 'path' points to a symlink, this will return True only if the target exists""" - p = Path(path) - return p.is_absolute() and p.exists() + self.config = config + self.index_or_path: IndexOrPath = config.index_or_path + self.fps: int | None = config.fps + self.channels: int = config.channels + self.color_mode: ColorMode = config.color_mode -def get_camera_index_from_unix_port(port: Path) -> int: - return int(str(port.resolve()).removeprefix("/dev/video")) + self.videocapture_camera: cv2.VideoCapture | None = None + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_queue: queue.Queue = queue.Queue(maxsize=1) -def save_image(img_array, camera_index, frame_index, images_dir): - img = Image.fromarray(img_array) - path = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png" - path.parent.mkdir(parents=True, exist_ok=True) - img.save(str(path), quality=100) + self.logs: dict = {} # NOTE(Steven): Might be removed in the future + self.rotation: int | None = get_cv2_rotation(config.rotation) + self.backend: int = get_cv2_backend() # NOTE(Steven): If I specify backend the opencv open fails -def save_images_from_cameras( - images_dir: Path, - camera_ids: list | None = None, - fps=None, - width=None, - height=None, - record_time_s=2, - mock=False, -): - """ - Initializes all the cameras and saves images to the directory. Useful to visually identify the camera - associated to a given camera index. - """ - if camera_ids is None or len(camera_ids) == 0: - camera_infos = find_cameras(mock=mock) - camera_ids = [cam["index"] for cam in camera_infos] - - print("Connecting cameras") - cameras = [] - for cam_idx in camera_ids: - config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock) - camera = OpenCVCamera(config) - camera.connect() - print( - f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, " - f"height={camera.capture_height}, color_mode={camera.color_mode})" - ) - cameras.append(camera) + # NOTE(Steven): What happens if rotation is specified but we leave width and height to None? + # NOTE(Steven): Should we enforce these parameters if rotation is set? + if self.height and self.width: + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.prerotated_width, self.prerotated_height = self.height, self.width + else: + self.prerotated_width, self.prerotated_height = self.width, self.height - images_dir = Path(images_dir) - if images_dir.exists(): - shutil.rmtree( - images_dir, - ) - images_dir.mkdir(parents=True, exist_ok=True) - - print(f"Saving images to {images_dir}") - frame_index = 0 - start_time = time.perf_counter() - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - while True: - now = time.perf_counter() - - for camera in cameras: - # If we use async_read when fps is None, the loop will go full speed, and we will endup - # saving the same images from the cameras multiple times until the RAM/disk is full. 
- image = camera.read() if fps is None else camera.async_read() - - executor.submit( - save_image, - image, - camera.camera_index, - frame_index, - images_dir, - ) - - if fps is not None: - dt_s = time.perf_counter() - now - busy_wait(1 / fps - dt_s) - - print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") - - if time.perf_counter() - start_time > record_time_s: - break + def __str__(self) -> str: + """Returns a string representation of the camera instance.""" + return f"{self.__class__.__name__}({self.index_or_path})" - frame_index += 1 + @property + def is_connected(self) -> bool: + """Checks if the camera is currently connected and opened.""" + return isinstance(self.videocapture_camera, cv2.VideoCapture) and self.videocapture_camera.isOpened() - print(f"Images have been saved to {images_dir}") + def _configure_capture_settings(self) -> None: + """ + Applies the specified FPS, width, and height settings to the connected camera. + This method attempts to set the camera properties via OpenCV. It checks if + the camera successfully applied the settings and raises an error if not. -class OpenCVCamera(Camera): - """ - The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate - with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). + Args: + fps: The desired frames per second. If None, the setting is skipped. + width: The desired capture width. If None, the setting is skipped. + height: The desired capture height. If None, the setting is skipped. - An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera - like a webcam of a laptop, the camera index is expected to be 0, but it might also be very different, and the camera index - might change if you reboot your computer or re-plug your camera. This behavior depends on your operation system. + Raises: + RuntimeError: If the camera fails to set any of the specified properties + to the requested value. + DeviceNotConnectedError: If the camera is not connected when attempting + to configure settings. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") - To find the camera indices of your cameras, you can run our utility script that will be save a few frames for each camera: - ```bash - python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras - ``` + self._validate_fps() + self._validate_width_and_height() - When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode - of the given camera will be used. + def connect(self, do_warmup_read: bool = True): + """ + Connects to the OpenCV camera specified in the configuration. - Example of usage: - ```python - from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig + Initializes the OpenCV VideoCapture object, sets desired camera properties + (FPS, width, height), and performs initial checks. - config = OpenCVCameraConfig(camera_index=0) - camera = OpenCVCamera(config) - camera.connect() - color_image = camera.read() - # when done using the camera, consider disconnecting - camera.disconnect() - ``` + Raises: + DeviceAlreadyConnectedError: If the camera is already connected. + ValueError: If the specified camera index/path is not found or accessible. 
+ ConnectionError: If the camera is found but fails to open. + RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} is already connected.") - Example of changing default fps, width, height and color_mode: - ```python - config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720) - config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480) - config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr") - # Note: might error out open `camera.connect()` if these settings are not compatible with the camera - ``` - """ + # Use 1 thread for OpenCV operations to avoid potential conflicts or + # blocking in multi-threaded applications, especially during data collection. + cv2.setNumThreads(1) - def __init__(self, config: OpenCVCameraConfig): - self.config = config - self.camera_index = config.camera_index - self.port = None + logger.debug(f"Attempting to connect to camera {self.index_or_path} using backend {self.backend}...") + self.videocapture_camera = cv2.VideoCapture(self.index_or_path) - # Linux uses ports for connecting to cameras - if platform.system() == "Linux": - if isinstance(self.camera_index, int): - self.port = Path(f"/dev/video{self.camera_index}") - elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index): - self.port = Path(self.camera_index) - # Retrieve the camera index from a potentially symlinked path - self.camera_index = get_camera_index_from_unix_port(self.port) - else: - raise ValueError(f"Please check the provided camera_index: {self.camera_index}") + if not self.videocapture_camera.isOpened(): + self.videocapture_camera.release() + self.videocapture_camera = None + raise ConnectionError( + f"Failed to open OpenCV camera {self.index_or_path}." + f"Run 'python -m find_cameras list-cameras' for details." + ) - # Store the raw (capture) resolution from the config. - self.capture_width = config.width - self.capture_height = config.height + logger.debug(f"Successfully opened camera {self.index_or_path}. Applying configuration...") + self._configure_capture_settings() - # If rotated by ±90, swap width and height. 
- if config.rotation in [-90, 90]: - self.width = config.height - self.height = config.width - else: - self.width = config.width - self.height = config.height + if do_warmup_read: + logger.debug(f"Reading a warm-up frame for {self.index_or_path}...") + self.read() # NOTE(Steven): For now we just read one frame, we could also loop for X secs\ - self.fps = config.fps - self.channels = config.channels - self.color_mode = config.color_mode - self.mock = config.mock + logger.debug(f"Camera {self.index_or_path} connected and configured successfully.") - self.camera = None - self.is_connected = False - self.thread = None - self.stop_event = None - self.color_image = None - self.logs = {} + def _validate_fps(self) -> None: + """Validates and sets the camera's frames per second (FPS).""" - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 + if self.fps is None: + self.fps = self.videocapture_camera.get(cv2.CAP_PROP_FPS) + logger.info(f"FPS set to camera default: {self.fps}.") + return - self.rotation = None - if config.rotation == -90: - self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE - elif config.rotation == 90: - self.rotation = cv2.ROTATE_90_CLOCKWISE - elif config.rotation == 180: - self.rotation = cv2.ROTATE_180 + success = self.videocapture_camera.set(cv2.CAP_PROP_FPS, float(self.fps)) + actual_fps = self.videocapture_camera.get(cv2.CAP_PROP_FPS) + # Use math.isclose for robust float comparison + if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + logger.warning( + f"Requested FPS {self.fps} for {self}, but camera reported {actual_fps} (set success: {success}). " + "This might be due to camera limitations." + ) + raise RuntimeError( + f"Failed to set requested FPS {self.fps} for {self}. Actual value reported: {actual_fps}." + ) + logger.debug(f"FPS set to {actual_fps} for {self}.") - def connect(self): - if self.is_connected: - raise DeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") + def _validate_width_and_height(self) -> None: + """Validates and sets the camera's frame capture width and height.""" - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - # Use 1 thread to avoid blocking the main thread. Especially useful during data collection - # when other threads are used to save the images. - cv2.setNumThreads(1) - - backend = ( - cv2.CAP_V4L2 - if platform.system() == "Linux" - else cv2.CAP_DSHOW - if platform.system() == "Windows" - else cv2.CAP_AVFOUNDATION - if platform.system() == "Darwin" - else cv2.CAP_ANY - ) + actual_width = int(round(self.videocapture_camera.get(cv2.CAP_PROP_FRAME_WIDTH))) + actual_height = int(round(self.videocapture_camera.get(cv2.CAP_PROP_FRAME_HEIGHT))) - camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index - # First create a temporary camera trying to access `camera_index`, - # and verify it is a valid camera by calling `isOpened`. - tmp_camera = cv2.VideoCapture(camera_idx, backend) - is_camera_open = tmp_camera.isOpened() - # Release camera to make it accessible for `find_camera_indices` - tmp_camera.release() - del tmp_camera - - # If the camera doesn't work, display the camera indices corresponding to - # valid cameras. 
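A short illustration of why `_validate_fps` compares with `math.isclose` rather than strict equality: drivers commonly report a fractional rate close to, but not exactly, the requested one. The helper name is illustrative only.

```python
# Why the FPS check uses math.isclose instead of ==: cameras often report a
# rate that is close to, but not exactly, the requested value.
import math

def fps_matches(requested: float, reported: float, rel_tol: float = 1e-3) -> bool:
    """True if the camera's reported FPS is within 0.1% of the requested FPS."""
    return math.isclose(requested, reported, rel_tol=rel_tol)

assert fps_matches(30, 29.99)      # near-miss reported by the driver: accepted
assert not fps_matches(30, 15.0)   # camera silently fell back to a lower rate: rejected
assert not fps_matches(30, 30.5)   # off by more than the tolerance: rejected
```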
- if not is_camera_open: - # Verify that the provided `camera_index` is valid before printing the traceback - cameras_info = find_cameras() - available_cam_ids = [cam["index"] for cam in cameras_info] - if self.camera_index not in available_cam_ids: - raise ValueError( - f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. " - "To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`." - ) - - raise OSError(f"Can't access OpenCVCamera({camera_idx}).") - - # Secondly, create the camera that will be used downstream. - # Note: For some unknown reason, calling `isOpened` blocks the camera which then - # needs to be re-created. - self.camera = cv2.VideoCapture(camera_idx, backend) - - if self.fps is not None: - self.camera.set(cv2.CAP_PROP_FPS, self.fps) - if self.capture_width is not None: - self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width) - if self.capture_height is not None: - self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height) - - actual_fps = self.camera.get(cv2.CAP_PROP_FPS) - actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH) - actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT) - - # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): - # Using `OSError` since it's a broad that encompasses issues related to device communication - raise OSError( - f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}." + # NOTE(Steven): When do we constraint the possibility of only setting one? + if self.width is None or self.height is None: + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.width, self.height = actual_height, actual_width + self.prerotated_width, self.prerotated_height = actual_width, actual_height + else: + self.width, self.height = actual_width, actual_height + self.prerotated_width, self.prerotated_height = actual_width, actual_height + logger.info(f"Capture width set to camera default: {self.width}.") + logger.info(f"Capture height set to camera default: {self.height}.") + return + + success = self.videocapture_camera.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.prerotated_width)) + if not success or self.prerotated_width != actual_width: + logger.warning( + f"Requested capture width {self.prerotated_width} for {self}, but camera reported {actual_width} (set success: {success})." + ) + raise RuntimeError( + f"Failed to set requested capture width {self.prerotated_width} for {self}. Actual value: {actual_width}." ) - if self.capture_width is not None and not math.isclose( - self.capture_width, actual_width, rel_tol=1e-3 - ): - raise OSError( - f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}." + logger.debug(f"Capture width set to {actual_width} for {self}.") + + success = self.videocapture_camera.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.prerotated_height)) + if not success or self.prerotated_height != actual_height: + logger.warning( + f"Requested capture height {self.prerotated_height} for {self}, but camera reported {actual_height} (set success: {success})." ) - if self.capture_height is not None and not math.isclose( - self.capture_height, actual_height, rel_tol=1e-3 - ): - raise OSError( - f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). 
Actual value is {actual_height}." + raise RuntimeError( + f"Failed to set requested capture height {self.prerotated_height} for {self}. Actual value: {actual_height}." ) + logger.debug(f"Capture height set to {actual_height} for {self}.") - self.fps = round(actual_fps) - self.capture_width = round(actual_width) - self.capture_height = round(actual_height) - self.is_connected = True + @staticmethod + def find_cameras( + max_index_search_range=MAX_OPENCV_INDEX, raise_when_empty: bool = True + ) -> List[Dict[str, Any]]: + """ + Detects available OpenCV cameras connected to the system. - def read(self, temporary_color_mode: str | None = None) -> np.ndarray: - """Read a frame from the camera returned in the format (height, width, channels) - (e.g. 480 x 640 x 3), contrarily to the pytorch format which is channel first. + On Linux, it scans '/dev/video*' paths. On other systems (like macOS, Windows), + it checks indices from 0 up to `max_index_search_range`. - Note: Reading a frame is done every `camera.fps` times per second, and it is blocking. - If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`. + Args: + max_index_search_range (int): The maximum index to check on non-Linux systems. + raise_when_empty (bool): If True, raises an OSError if no cameras are found. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, + where each dictionary contains 'type', 'id' (port index or path), + and the default profile properties (width, height, fps, format). """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." + found_cameras_info = [] + + if platform.system() == "Linux": + logger.info("Linux detected. Scanning '/dev/video*' device paths...") + possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name) + targets_to_scan = [str(p) for p in possible_paths] + logger.debug(f"Found potential paths: {targets_to_scan}") + else: + logger.info( + f"{platform.system()} system detected. Scanning indices from 0 to {max_index_search_range}..." ) + targets_to_scan = list(range(max_index_search_range)) + + for target in targets_to_scan: + camera = cv2.VideoCapture(target) + if camera.isOpened(): + default_width = int(camera.get(cv2.CAP_PROP_FRAME_WIDTH)) + default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT)) + default_fps = camera.get(cv2.CAP_PROP_FPS) + default_format = camera.get(cv2.CAP_PROP_FORMAT) + camera_info = { + "name": f"OpenCV Camera @ {target}", + "type": "OpenCV", + "id": target, + "backend_api": camera.getBackendName(), + "default_stream_profile": { + "format": default_format, + "width": default_width, + "height": default_height, + "fps": default_fps, + }, + } - start_time = time.perf_counter() + found_cameras_info.append(camera_info) + logger.debug(f"Found OpenCV camera:: {camera_info}") + camera.release() - ret, color_image = self.camera.read() + if not found_cameras_info: + logger.warning("No OpenCV devices detected.") + if raise_when_empty: + raise OSError("No OpenCV devices detected. 
Ensure cameras are connected.") - if not ret: - raise OSError(f"Can't capture color image from camera {self.camera_index}.") + logger.info(f"Detected OpenCV cameras: {[cam['id'] for cam in found_cameras_info]}") + return found_cameras_info - requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode + def read(self, color_mode: ColorMode | None = None) -> np.ndarray: + """ + Reads a single frame synchronously from the camera. + + This is a blocking call. It waits for the next available frame from the + camera hardware via OpenCV. + + Args: + color_mode (Optional[ColorMode]): If specified, overrides the default + color mode (`self.color_mode`) for this read operation (e.g., + request RGB even if default is BGR). + + Returns: + np.ndarray: The captured frame as a NumPy array in the format + (height, width, channels), using the specified or default + color mode and applying any configured rotation. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading the frame from the camera fails or if the + received frame dimensions don't match expectations before rotation. + ValueError: If an invalid `color_mode` is requested. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") - if requested_color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." + start_time = time.perf_counter() + + # NOTE(Steven): Are we okay with this blocking an undefined amount of time? + ret, frame = self.videocapture_camera.read() + + if not ret or frame is None: + raise RuntimeError( + f"Failed to capture frame from {self}. '.read()' returned status={ret} and frame is None." ) - # OpenCV uses BGR format as default (blue, green, red) for all operations, including displaying images. - # However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks, - # so we convert the image color from BGR to RGB. - if requested_color_mode == "rgb": - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 + # Post-process the frame (color conversion, dimension check, rotation) + processed_frame = self._postprocess_image(frame, color_mode) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} synchronous read took: {read_duration_ms:.1f}ms") - color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB) + self.logs["timestamp_utc"] = capture_timestamp_utc() + return processed_frame - h, w, _ = color_image.shape - if h != self.capture_height or w != self.capture_width: - raise OSError( - f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." + def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray: + """ + Applies color conversion, dimension validation, and rotation to a raw frame. + + Args: + image (np.ndarray): The raw image frame (expected BGR format from OpenCV). + color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, + uses the instance's default `self.color_mode`. + + Returns: + np.ndarray: The processed image frame. + + Raises: + ValueError: If the requested `color_mode` is invalid. + RuntimeError: If the raw frame dimensions do not match the configured + `width` and `height`. 
+ """ + requested_color_mode = self.color_mode if color_mode is None else color_mode + + if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"Invalid requested color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." ) - if self.rotation is not None: - color_image = cv2.rotate(color_image, self.rotation) + h, w, c = image.shape - # log the number of seconds it took to read the image - self.logs["delta_timestamp_s"] = time.perf_counter() - start_time + if h != self.prerotated_height or w != self.prerotated_width: + raise RuntimeError( + f"Captured frame dimensions ({h}x{w}) do not match configured capture dimensions ({self.prerotated_height}x{self.prerotated_width}) for {self}." + ) + if c != self.channels: + logger.warning( + f"Captured frame channels ({c}) do not match configured channels ({self.channels}) for {self}." + ) - # log the utc time at which the image was received - self.logs["timestamp_utc"] = capture_timestamp_utc() + processed_image = image + if requested_color_mode == ColorMode.RGB: + processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + logger.debug(f"Converted frame from BGR to RGB for {self}.") - self.color_image = color_image + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + processed_image = cv2.rotate(processed_image, self.rotation) + logger.debug(f"Rotated frame by {self.config.rotation} degrees for {self}.") - return color_image + return processed_image - def read_loop(self): + def _read_loop(self): + """ + Internal loop run by the background thread for asynchronous reading. + + Continuously reads frames from the camera using the synchronous `read()` + method and places the latest frame into the `frame_queue`. It overwrites + any previous frame in the queue. + """ + logger.debug(f"Starting read loop thread for {self}.") while not self.stop_event.is_set(): try: - self.color_image = self.read() + color_image = self.read() + + with contextlib.suppress(queue.Empty): + _ = self.frame_queue.get_nowait() + self.frame_queue.put(color_image) + logger.debug(f"Frame placed in queue for {self}.") + + except DeviceNotConnectedError: + logger.error(f"Read loop for {self} stopped: Camera disconnected.") + break except Exception as e: - print(f"Error reading in thread: {e}") + logger.warning(f"Error reading frame in background thread for {self}: {e}") + + logger.debug(f"Stopping read loop thread for {self}.") + + def _ensure_read_thread_running(self): + """Starts or restarts the background read thread if it's not running.""" + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=0.1) + if self.stop_event is not None: + self.stop_event.set() + + self.stop_event = Event() + self.thread = Thread( + target=self._read_loop, args=(), name=f"OpenCVCameraReadLoop-{self}-{self.index_or_path}" + ) + self.thread.daemon = True + self.thread.start() + logger.debug(f"Read thread started for {self}.") - def async_read(self): + def async_read(self, timeout_ms: float = 2000) -> np.ndarray: + """ + Reads the latest available frame asynchronously. + + This method retrieves the most recent frame captured by the background + read thread. It does not block waiting for the camera hardware directly, + only waits for a frame to appear in the internal queue up to the specified + timeout. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available in the queue. Defaults to 2000ms (2 seconds). 
+ + Returns: + np.ndarray: The latest captured frame as a NumPy array in the format + (height, width, channels), processed according to configuration. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame becomes available within the specified timeout. + RuntimeError: If an unexpected error occurs while retrieving from the queue. + """ if not self.is_connected: - raise DeviceNotConnectedError( - f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + self._ensure_read_thread_running() + + try: + return self.frame_queue.get(timeout=timeout_ms / 1000.0) + except queue.Empty as e: + thread_alive = self.thread is not None and self.thread.is_alive() + logger.error( + f"Timeout waiting for frame from {self} queue after {timeout_ms}ms. " + f"(Read thread alive: {thread_alive})" ) + raise TimeoutError( + f"Timed out waiting for frame from camera {self.index_or_path} after {timeout_ms} ms. " + f"Read thread alive: {thread_alive}." + ) from e + except Exception as e: + logger.exception(f"Unexpected error getting frame from queue for {self}: {e}") + raise RuntimeError(f"Error getting frame from queue for camera {self.index_or_path}: {e}") from e + + def _shutdown_read_thread(self): + """Signals the background read thread to stop and waits for it to join.""" + if self.stop_event is not None: + logger.debug(f"Signaling stop event for read thread of {self}.") + self.stop_event.set() - if self.thread is None: - self.stop_event = threading.Event() - self.thread = Thread(target=self.read_loop, args=()) - self.thread.daemon = True - self.thread.start() - - num_tries = 0 - while True: - if self.color_image is not None: - return self.color_image + if self.thread is not None and self.thread.is_alive(): + logger.debug(f"Waiting for read thread of {self} to join...") + self.thread.join(timeout=2.0) + if self.thread.is_alive(): + logger.warning(f"Read thread for {self} did not terminate gracefully after 2 seconds.") + else: + logger.debug(f"Read thread for {self} joined successfully.") - time.sleep(1 / self.fps) - num_tries += 1 - if num_tries > self.fps * 2: - raise TimeoutError("Timed out waiting for async_read() to start.") + self.thread = None + self.stop_event = None def disconnect(self): - if not self.is_connected: + """ + Disconnects from the camera and cleans up resources. + + Stops the background read thread (if running) and releases the OpenCV + VideoCapture object. + + Raises: + DeviceNotConnectedError: If the camera is already disconnected. + """ + if not self.is_connected and self.thread is None: raise DeviceNotConnectedError( - f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." + f"Attempted to disconnect {self}, but it appears already disconnected." ) + logger.debug(f"Disconnecting from camera {self.index_or_path}...") + if self.thread is not None: - self.stop_event.set() - self.thread.join() # wait for the thread to finish - self.thread = None - self.stop_event = None - - self.camera.release() - self.camera = None - self.is_connected = False - - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset." 
- ) - parser.add_argument( - "--camera-ids", - type=int, - nargs="*", - default=None, - help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.", - ) - parser.add_argument( - "--fps", - type=int, - default=None, - help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.", - ) - parser.add_argument( - "--width", - type=int, - default=None, - help="Set the width for all cameras. If not provided, use the default width of each camera.", - ) - parser.add_argument( - "--height", - type=int, - default=None, - help="Set the height for all cameras. If not provided, use the default height of each camera.", - ) - parser.add_argument( - "--images-dir", - type=Path, - default="outputs/images_from_opencv_cameras", - help="Set directory to save a few frames for each camera.", - ) - parser.add_argument( - "--record-time-s", - type=float, - default=4.0, - help="Set the number of seconds used to record the frames. By default, 2 seconds.", - ) - args = parser.parse_args() - save_images_from_cameras(**vars(args)) + self._shutdown_read_thread() + + if self.videocapture_camera is not None: + logger.debug(f"Releasing OpenCV VideoCapture object for {self}.") + self.videocapture_camera.release() + self.videocapture_camera = None + + logger.info(f"Camera {self.index_or_path} disconnected successfully.") diff --git a/lerobot/common/cameras/opencv/configuration_opencv.py b/lerobot/common/cameras/opencv/configuration_opencv.py index 983199bfda..ab9fc2f726 100644 --- a/lerobot/common/cameras/opencv/configuration_opencv.py +++ b/lerobot/common/cameras/opencv/configuration_opencv.py @@ -1,6 +1,7 @@ from dataclasses import dataclass +from pathlib import Path -from ..configs import CameraConfig +from ..configs import CameraConfig, ColorMode, Cv2Rotation @CameraConfig.register_subclass("opencv") @@ -8,7 +9,7 @@ class OpenCVCameraConfig(CameraConfig): """ Example of tested options for Intel Real Sense D405: - + #NOTE(Steven): update this doc ```python OpenCVCameraConfig(0, 30, 640, 480) OpenCVCameraConfig(0, 60, 640, 480) @@ -17,22 +18,26 @@ class OpenCVCameraConfig(CameraConfig): ``` """ - camera_index: int - fps: int | None = None - width: int | None = None - height: int | None = None - color_mode: str = "rgb" - channels: int | None = None - rotation: int | None = None - mock: bool = False + index_or_path: int | Path + color_mode: ColorMode = ColorMode.RGB + channels: int = 3 # NOTE(Steven): Why is this a config? + rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION def __post_init__(self): - if self.color_mode not in ["rgb", "bgr"]: + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." + f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." ) - self.channels = 3 + if self.rotation not in ( + Cv2Rotation.NO_ROTATION, + Cv2Rotation.ROTATE_90, + Cv2Rotation.ROTATE_180, + Cv2Rotation.ROTATE_270, + ): + raise ValueError( + f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." 
+ ) - if self.rotation not in [-90, None, 90, 180]: - raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") + if self.channels != 3: + raise NotImplementedError(f"Unsupported number of channels: {self.channels}") diff --git a/lerobot/common/cameras/utils.py b/lerobot/common/cameras/utils.py index f0b6ce5f27..30f5dd69ae 100644 --- a/lerobot/common/cameras/utils.py +++ b/lerobot/common/cameras/utils.py @@ -1,5 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +from pathlib import Path +from typing import TypeAlias + +import numpy as np +from PIL import Image + from .camera import Camera -from .configs import CameraConfig +from .configs import CameraConfig, Cv2Rotation + +IndexOrPath: TypeAlias = int | Path def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]: @@ -19,3 +44,30 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s raise ValueError(f"The motor type '{cfg.type}' is not valid.") return cameras + + +def get_cv2_rotation(rotation: Cv2Rotation) -> int: + import cv2 + + return { + Cv2Rotation.ROTATE_270: cv2.ROTATE_90_COUNTERCLOCKWISE, + Cv2Rotation.ROTATE_90: cv2.ROTATE_90_CLOCKWISE, + Cv2Rotation.ROTATE_180: cv2.ROTATE_180, + }.get(rotation) + + +def get_cv2_backend() -> int: + import cv2 + + return { + "Linux": cv2.CAP_DSHOW, + "Windows": cv2.CAP_AVFOUNDATION, + "Darwin": cv2.CAP_ANY, + }.get(platform.system(), cv2.CAP_V4L2) + + +def save_image(img_array: np.ndarray, camera_index: int, frame_index: int, images_dir: Path): + img = Image.fromarray(img_array) + path = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png" + path.parent.mkdir(parents=True, exist_ok=True) + img.save(str(path), quality=100) diff --git a/lerobot/common/robots/config.py b/lerobot/common/robots/config.py index 83a13ca937..3fd2872c7c 100644 --- a/lerobot/common/robots/config.py +++ b/lerobot/common/robots/config.py @@ -12,6 +12,17 @@ class RobotConfig(draccus.ChoiceRegistry, abc.ABC): # Directory to store calibration file calibration_dir: Path | None = None + def __post_init__(self): + if hasattr(self, "cameras"): + cameras = self.cameras + if cameras: + for cam_name, cam_config in cameras.items(): + for attr in ["width", "height", "fps"]: + if getattr(cam_config, attr) is None: + raise ValueError( + f"Camera config for '{cam_name}' has None value for required attribute '{attr}'" + ) + @property def type(self) -> str: return self.get_choice_name(self.__class__) diff --git a/lerobot/common/robots/stretch3/configuration_stretch3.py b/lerobot/common/robots/stretch3/configuration_stretch3.py index 47ddb54bb1..e80993a85f 100644 --- a/lerobot/common/robots/stretch3/configuration_stretch3.py +++ b/lerobot/common/robots/stretch3/configuration_stretch3.py @@ -19,7 +19,7 @@ class Stretch3RobotConfig(RobotConfig): cameras: dict[str, CameraConfig] = field( default_factory=lambda: { "navigation": 
OpenCVCameraConfig( - camera_index="/dev/hello-nav-head-camera", + index_or_path="/dev/hello-nav-head-camera", fps=10, width=1280, height=720, diff --git a/lerobot/common/utils/control_utils.py b/lerobot/common/utils/control_utils.py index 3ac792a5dc..730144f33e 100644 --- a/lerobot/common/utils/control_utils.py +++ b/lerobot/common/utils/control_utils.py @@ -38,6 +38,7 @@ from lerobot.common.utils.utils import get_safe_torch_device, has_method +# NOTE(Steven): Consider integrating this in camera class def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): log_items = [] if episode_index is not None: diff --git a/lerobot/find_cameras.py b/lerobot/find_cameras.py new file mode 100644 index 0000000000..cd3195caa5 --- /dev/null +++ b/lerobot/find_cameras.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import concurrent.futures +import logging +import shutil +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from PIL import Image + +from lerobot.common.cameras.configs import ColorMode +from lerobot.common.cameras.intel.camera_realsense import RealSenseCamera +from lerobot.common.cameras.intel.configuration_realsense import RealSenseCameraConfig +from lerobot.common.cameras.opencv.camera_opencv import OpenCVCamera +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig + +logger = logging.getLogger(__name__) +# logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(module)s - %(message)s") + + +def find_all_opencv_cameras() -> List[Dict[str, Any]]: + """ + Finds all available OpenCV cameras plugged into the system. + + Returns: + A list of all available OpenCV cameras with their metadata. + """ + all_opencv_cameras_info: List[Dict[str, Any]] = [] + logger.info("Searching for OpenCV cameras...") + try: + opencv_cameras = OpenCVCamera.find_cameras(raise_when_empty=False) + for cam_info in opencv_cameras: + all_opencv_cameras_info.append(cam_info) + logger.info(f"Found {len(opencv_cameras)} OpenCV cameras.") + except Exception as e: + logger.error(f"Error finding OpenCV cameras: {e}") + + return all_opencv_cameras_info + + +def find_all_realsense_cameras() -> List[Dict[str, Any]]: + """ + Finds all available RealSense cameras plugged into the system. + + Returns: + A list of all available RealSense cameras with their metadata. 
+ """ + all_realsense_cameras_info: List[Dict[str, Any]] = [] + logger.info("Searching for RealSense cameras...") + try: + realsense_cameras = RealSenseCamera.find_cameras(raise_when_empty=False) + for cam_info in realsense_cameras: + all_realsense_cameras_info.append(cam_info) + logger.info(f"Found {len(realsense_cameras)} RealSense cameras.") + except ImportError: + logger.warning("Skipping RealSense camera search: pyrealsense2 library not found or not importable.") + except Exception as e: + logger.error(f"Error finding RealSense cameras: {e}") + + return all_realsense_cameras_info + + +def find_and_print_cameras(camera_type_filter: Optional[str] = None) -> List[Dict[str, Any]]: + """ + Finds available cameras based on an optional filter and prints their information. + + Args: + camera_type_filter: Optional string to filter cameras ("realsense" or "opencv"). + If None, lists all cameras. + + Returns: + A list of all available cameras matching the filter, with their metadata. + """ + all_cameras_info: List[Dict[str, Any]] = [] + + if camera_type_filter: + camera_type_filter = camera_type_filter.lower() + + if camera_type_filter is None or camera_type_filter == "opencv": + all_cameras_info.extend(find_all_opencv_cameras()) + if camera_type_filter is None or camera_type_filter == "realsense": + all_cameras_info.extend(find_all_realsense_cameras()) + + if not all_cameras_info: + if camera_type_filter: + logger.warning(f"No {camera_type_filter} cameras were detected.") + else: + logger.warning("No cameras (OpenCV or RealSense) were detected.") + else: + print("\n--- Detected Cameras ---") + for i, cam_info in enumerate(all_cameras_info): + print(f"Camera #{i + 1}:") + for key, value in cam_info.items(): + if key == "default_stream_profile" and isinstance(value, dict): + print(f" {key.replace('_', ' ').capitalize()}:") + for sub_key, sub_value in value.items(): + print(f" {sub_key.capitalize()}: {sub_value}") + else: + print(f" {key.replace('_', ' ').capitalize()}: {value}") + print("-" * 20) + return all_cameras_info + + +def save_image( + img_array: np.ndarray, + camera_identifier: Union[str, int], + images_dir: Path, + camera_type: str, +): + """ + Saves a single image to disk using Pillow. Handles color conversion if necessary. + """ + try: + img = Image.fromarray(img_array, mode="RGB") + + safe_identifier = str(camera_identifier).replace("/", "_").replace("\\", "_") + filename_prefix = f"{camera_type.lower()}_{safe_identifier}" + filename = f"{filename_prefix}.png" + + path = images_dir / filename + path.parent.mkdir(parents=True, exist_ok=True) + img.save(str(path)) + logger.info(f"Saved image: {path}") + except Exception as e: + logger.error(f"Failed to save image for camera {camera_identifier} (type {camera_type}): {e}") + + +def initialize_output_directory(output_dir: Union[str, Path]) -> Path: + """Initialize and clean the output directory.""" + output_dir = Path(output_dir) + if output_dir.exists(): + logger.info(f"Output directory {output_dir} exists. 
Removing previous content.") + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving images to {output_dir}") + return output_dir + + +def create_camera_instance(cam_meta: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Create and connect to a camera instance based on metadata.""" + cam_type = cam_meta.get("type") + cam_id = cam_meta.get("id") + instance = None + + logger.info(f"Preparing {cam_type} ID {cam_id} with default profile") + + try: + if cam_type == "OpenCV": + cv_config = OpenCVCameraConfig( + index_or_path=cam_id, + color_mode=ColorMode.RGB, + ) + instance = OpenCVCamera(cv_config) + elif cam_type == "RealSense": + rs_config = RealSenseCameraConfig( + serial_number=str(cam_id), + color_mode=ColorMode.RGB, + ) + instance = RealSenseCamera(rs_config) + else: + logger.warning(f"Unknown camera type: {cam_type} for ID {cam_id}. Skipping.") + return None + + if instance: + logger.info(f"Connecting to {cam_type} camera: {cam_id}...") + instance.connect() + return {"instance": instance, "meta": cam_meta} + except Exception as e: + logger.error(f"Failed to connect or configure {cam_type} camera {cam_id}: {e}") + if instance and instance.is_connected: + instance.disconnect() + return None + + +def process_camera_image( + cam_dict: Dict[str, Any], output_dir: Path, current_time: float +) -> Optional[concurrent.futures.Future]: + """Capture and process an image from a single camera.""" + cam = cam_dict["instance"] + meta = cam_dict["meta"] + cam_type_str = str(meta.get("type", "unknown")) + cam_id_str = str(meta.get("id", "unknown")) + + try: + image_data = cam.read() + + return save_image( + image_data, + cam_id_str, + output_dir, + cam_type_str, + ) + except TimeoutError: + logger.warning( + f"Timeout reading from {cam_type_str} camera {cam_id_str} at time {current_time:.2f}s." + ) + except Exception as e: + logger.error(f"Error reading from {cam_type_str} camera {cam_id_str}: {e}") + return None + + +def cleanup_cameras(cameras_to_use: List[Dict[str, Any]]): + """Disconnect all cameras.""" + logger.info(f"Disconnecting {len(cameras_to_use)} cameras...") + for cam_dict in cameras_to_use: + try: + if cam_dict["instance"] and cam_dict["instance"].is_connected: + cam_dict["instance"].disconnect() + except Exception as e: + logger.error(f"Error disconnecting camera {cam_dict['meta'].get('id')}: {e}") + + +def save_images_from_all_cameras( + output_dir: Union[str, Path], + record_time_s: float = 2.0, + camera_type_filter: Optional[str] = None, +): + """ + Connects to detected cameras (optionally filtered by type) and saves images from each. + Uses default stream profiles for width, height, and FPS. + + Args: + output_dir: Directory to save images. + record_time_s: Duration in seconds to record images. + camera_type_filter: Optional string to filter cameras ("realsense" or "opencv"). + If None, uses all detected cameras. + """ + output_dir = initialize_output_directory(output_dir) + all_camera_metadata = find_and_print_cameras(camera_type_filter=camera_type_filter) + + if not all_camera_metadata: + logger.warning("No cameras detected matching the criteria. Cannot save images.") + return + + # Create and connect to all cameras + cameras_to_use = [] + for cam_meta in all_camera_metadata: + camera_instance = create_camera_instance(cam_meta) + if camera_instance: + cameras_to_use.append(camera_instance) + + if not cameras_to_use: + logger.warning("No cameras could be connected. 
Aborting image save.") + return + + logger.info(f"Starting image capture for {record_time_s} seconds from {len(cameras_to_use)} cameras.") + start_time = time.perf_counter() + + # NOTE(Steven): This seems like an overkill to me + with concurrent.futures.ThreadPoolExecutor(max_workers=len(cameras_to_use) * 2) as executor: + try: + while time.perf_counter() - start_time < record_time_s: + futures = [] + current_capture_time = time.perf_counter() + + for cam_dict in cameras_to_use: + future = process_camera_image(cam_dict, output_dir, current_capture_time) + if future: + futures.append(future) + + if futures: + concurrent.futures.wait(futures) + + except KeyboardInterrupt: + logger.info("Capture interrupted by user.") + finally: + print("\nFinalizing image saving...") + executor.shutdown(wait=True) + cleanup_cameras(cameras_to_use) + logger.info(f"Image capture finished. Images saved to {output_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Unified camera utility script for listing cameras and capturing images." + ) + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # List cameras command + list_parser = subparsers.add_parser( + "list-cameras", help="Shows connected cameras. Optionally filter by type (realsense or opencv)." + ) + list_parser.add_argument( + "camera_type", + type=str, + nargs="?", + default=None, + choices=["realsense", "opencv"], + help="Specify camera type to list (e.g., 'realsense', 'opencv'). Lists all if omitted.", + ) + list_parser.set_defaults(func=lambda args: find_and_print_cameras(args.camera_type)) + + # Capture images command + capture_parser = subparsers.add_parser( + "capture-images", + help="Saves images from detected cameras (optionally filtered by type) using their default stream profiles.", + ) + capture_parser.add_argument( + "camera_type", + type=str, + nargs="?", + default=None, + choices=["realsense", "opencv"], + help="Specify camera type to capture from (e.g., 'realsense', 'opencv'). Captures from all if omitted.", + ) + capture_parser.add_argument( + "--output-dir", + type=Path, + default="outputs/captured_images", + help="Directory to save images. Default: outputs/captured_images", + ) + capture_parser.add_argument( + "--record-time-s", + type=float, + default=5.0, + help="Time duration to attempt capturing frames. 
Default: 0.5 seconds (usually enough for one frame).", + ) + capture_parser.set_defaults( + func=lambda args: save_images_from_all_cameras( + output_dir=args.output_dir, + record_time_s=args.record_time_s, + camera_type_filter=args.camera_type, + ) + ) + + args = parser.parse_args() + + if args.command is None: + default_output_dir = capture_parser.get_default("output_dir") + default_record_time_s = capture_parser.get_default("record_time_s") + + save_images_from_all_cameras( + output_dir=default_output_dir, + record_time_s=default_record_time_s, + camera_type_filter=None, + ) + else: + args.func(args) diff --git a/pyproject.toml b/pyproject.toml index 1578060eb2..381f5bbce7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,10 @@ dora = [ ] dynamixel = ["dynamixel-sdk>=3.7.31"] feetech = ["feetech-servo-sdk>=1.0.0"] -intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] +intelrealsense = [ + "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", + "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", #NOTE(Steven): Check previous version for sudo issue +] pi0 = ["transformers>=4.48.0"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ diff --git a/tests/artifacts/cameras/fakecam_fullhd_480x270.png b/tests/artifacts/cameras/fakecam_fullhd_480x270.png new file mode 100644 index 0000000000..b564d5424c --- /dev/null +++ b/tests/artifacts/cameras/fakecam_fullhd_480x270.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f79d14daafb1c0cf2fec5d46ee8029a73fe357402fdd31a7cd4a4794d7319a7c +size 260367 diff --git a/tests/artifacts/cameras/fakecam_hd_320x180.png b/tests/artifacts/cameras/fakecam_hd_320x180.png new file mode 100644 index 0000000000..4cfd511a71 --- /dev/null +++ b/tests/artifacts/cameras/fakecam_hd_320x180.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8840fb643afe903191248703b1f95a57faf5812ecd9978ac502ee939646fdb2 +size 121115 diff --git a/tests/artifacts/cameras/fakecam_sd_160x120.png b/tests/artifacts/cameras/fakecam_sd_160x120.png new file mode 100644 index 0000000000..cdc681d183 --- /dev/null +++ b/tests/artifacts/cameras/fakecam_sd_160x120.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e11af87616b83c1cdb30330e951b91e86b51c64a1326e1ba5b4a3fbcdec1a11 +size 55698 diff --git a/tests/artifacts/cameras/fakecam_square_128x128.png b/tests/artifacts/cameras/fakecam_square_128x128.png new file mode 100644 index 0000000000..b117f49f27 --- /dev/null +++ b/tests/artifacts/cameras/fakecam_square_128x128.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9dc9df05797dc0e7b92edc845caab2e4c37c3cfcabb4ee6339c67212b5baba3b +size 38023 diff --git a/tests/artifacts/cameras/test_rs.bag b/tests/artifacts/cameras/test_rs.bag new file mode 100644 index 0000000000..1b9662c356 --- /dev/null +++ b/tests/artifacts/cameras/test_rs.bag @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8d6e64d6cb0e02c94ae125630ee758055bd2e695772c0463a30d63ddc6c5e17 +size 3520862 diff --git a/tests/cameras/mock_cv2.py b/tests/cameras/mock_cv2.py deleted file mode 100644 index eeaf859cc2..0000000000 --- a/tests/cameras/mock_cv2.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from functools import cache - -import numpy as np - -CAP_V4L2 = 200 -CAP_DSHOW = 700 -CAP_AVFOUNDATION = 1200 -CAP_ANY = -1 - -CAP_PROP_FPS = 5 -CAP_PROP_FRAME_WIDTH = 3 -CAP_PROP_FRAME_HEIGHT = 4 -COLOR_RGB2BGR = 4 -COLOR_BGR2RGB = 4 - -ROTATE_90_COUNTERCLOCKWISE = 2 -ROTATE_90_CLOCKWISE = 0 -ROTATE_180 = 1 - - -@cache -def _generate_image(width: int, height: int): - return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8) - - -def cvtColor(color_image, color_conversion): # noqa: N802 - if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]: - return color_image[:, :, [2, 1, 0]] - else: - raise NotImplementedError(color_conversion) - - -def rotate(color_image, rotation): - if rotation is None: - return color_image - elif rotation == ROTATE_90_CLOCKWISE: - return np.rot90(color_image, k=1) - elif rotation == ROTATE_180: - return np.rot90(color_image, k=2) - elif rotation == ROTATE_90_COUNTERCLOCKWISE: - return np.rot90(color_image, k=3) - else: - raise NotImplementedError(rotation) - - -class VideoCapture: - def __init__(self, *args, **kwargs): - self._mock_dict = { - CAP_PROP_FPS: 30, - CAP_PROP_FRAME_WIDTH: 640, - CAP_PROP_FRAME_HEIGHT: 480, - } - self._is_opened = True - - def isOpened(self): # noqa: N802 - return self._is_opened - - def set(self, propId: int, value: float) -> bool: # noqa: N803 - if not self._is_opened: - raise RuntimeError("Camera is not opened") - self._mock_dict[propId] = value - return True - - def get(self, propId: int) -> float: # noqa: N803 - if not self._is_opened: - raise RuntimeError("Camera is not opened") - value = self._mock_dict[propId] - if value == 0: - if propId == CAP_PROP_FRAME_HEIGHT: - value = 480 - elif propId == CAP_PROP_FRAME_WIDTH: - value = 640 - return value - - def read(self): - if not self._is_opened: - raise RuntimeError("Camera is not opened") - h = self.get(CAP_PROP_FRAME_HEIGHT) - w = self.get(CAP_PROP_FRAME_WIDTH) - ret = True - return ret, _generate_image(width=w, height=h) - - def release(self): - self._is_opened = False - - def __del__(self): - if self._is_opened: - self.release() diff --git a/tests/cameras/mock_pyrealsense2.py b/tests/cameras/mock_pyrealsense2.py deleted file mode 100644 index c477eb0626..0000000000 --- a/tests/cameras/mock_pyrealsense2.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
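With the cv2 mock removed, the new tests lean on `cv2.VideoCapture` being able to open a still image and serve it as a one-frame stream. A minimal sketch, assuming the LFS artifacts added above are checked out and the snippet runs from the repository root.

```python
# The deleted mock modules are replaced by real decodable artifacts: OpenCV's
# VideoCapture opens the PNG and returns it as a single frame, which is what
# the new tests rely on when they pass an image path as index_or_path.
import cv2

cap = cv2.VideoCapture("tests/artifacts/cameras/fakecam_sd_160x120.png")
ok, frame = cap.read()
assert ok and frame is not None and frame.ndim == 3  # single BGR frame decoded from the PNG
cap.release()
```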
-import enum - -import numpy as np - - -class stream(enum.Enum): # noqa: N801 - color = 0 - depth = 1 - - -class format(enum.Enum): # noqa: N801 - rgb8 = 0 - z16 = 1 - - -class config: # noqa: N801 - def enable_device(self, device_id: str): - self.device_enabled = device_id - - def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None): - self.stream_type = stream_type - # Overwrite default values when possible - self.width = 848 if width is None else width - self.height = 480 if height is None else height - self.color_format = format.rgb8 if color_format is None else color_format - self.fps = 30 if fps is None else fps - - -class RSColorProfile: - def __init__(self, config): - self.config = config - - def fps(self): - return self.config.fps - - def width(self): - return self.config.width - - def height(self): - return self.config.height - - -class RSColorStream: - def __init__(self, config): - self.config = config - - def as_video_stream_profile(self): - return RSColorProfile(self.config) - - -class RSProfile: - def __init__(self, config): - self.config = config - - def get_stream(self, color_format): - del color_format # unused - return RSColorStream(self.config) - - -class pipeline: # noqa: N801 - def __init__(self): - self.started = False - self.config = None - - def start(self, config): - self.started = True - self.config = config - return RSProfile(self.config) - - def stop(self): - if not self.started: - raise RuntimeError("You need to start the camera before stop.") - self.started = False - self.config = None - - def wait_for_frames(self, timeout_ms=50000): - del timeout_ms # unused - return RSFrames(self.config) - - -class RSFrames: - def __init__(self, config): - self.config = config - - def get_color_frame(self): - return RSColorFrame(self.config) - - def get_depth_frame(self): - return RSDepthFrame(self.config) - - -class RSColorFrame: - def __init__(self, config): - self.config = config - - def get_data(self): - data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8) - # Create a difference between rgb and bgr - data[:, :, 0] = 2 - return data - - -class RSDepthFrame: - def __init__(self, config): - self.config = config - - def get_data(self): - return np.ones((self.config.height, self.config.width), dtype=np.uint16) - - -class RSDevice: - def __init__(self): - pass - - def get_info(self, camera_info) -> str: - del camera_info # unused - # return fake serial number - return "123456789" - - -class context: # noqa: N801 - def __init__(self): - pass - - def query_devices(self): - return [RSDevice()] - - -class camera_info: # noqa: N801 - # fake name - name = "Intel RealSense D435I" - - def __init__(self, serial_number): - del serial_number - pass diff --git a/tests/cameras/test_cameras.py b/tests/cameras/test_cameras.py deleted file mode 100644 index 6dbc716c97..0000000000 --- a/tests/cameras/test_cameras.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for physical cameras and their mocked versions. -If the physical camera is not connected to the computer, or not working, -the test will be skipped. - -Example of running a specific test: -```bash -pytest -sx tests/test_cameras.py::test_camera -``` - -Example of running test on a real camera connected to the computer: -```bash -pytest -sx 'tests/test_cameras.py::test_camera[opencv-False]' -pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-False]' -``` - -Example of running test on a mocked version of the camera: -```bash -pytest -sx 'tests/test_cameras.py::test_camera[opencv-True]' -pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]' -``` -""" - -import numpy as np -import pytest - -from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError -from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera - -# Maximum absolute difference between two consecutive images recorded by a camera. -# This value differs with respect to the camera. -MAX_PIXEL_DIFFERENCE = 25 - - -def compute_max_pixel_difference(first_image, second_image): - return np.abs(first_image.astype(float) - second_image.astype(float)).max() - - -@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) -@require_camera -def test_camera(request, camera_type, mock): - """Test assumes that `camera.read()` returns the same image when called multiple times in a row. - So the environment should not change (you shouldnt be in front of the camera) and the camera should not be moving. - - Warning: The tests worked for a macbookpro camera, but I am getting assertion error (`np.allclose(color_image, async_color_image)`) - for my iphone camera and my LG monitor camera. - """ - # TODO(rcadene): measure fps in nightly? 
- # TODO(rcadene): test logs - - if camera_type == "opencv" and not mock: - pytest.skip("TODO(rcadene): fix test for opencv physical camera") - - camera_kwargs = {"camera_type": camera_type, "mock": mock} - - # Test instantiating - camera = make_camera(**camera_kwargs) - - # Test reading, async reading, disconnecting before connecting raises an error - with pytest.raises(DeviceNotConnectedError): - camera.read() - with pytest.raises(DeviceNotConnectedError): - camera.async_read() - with pytest.raises(DeviceNotConnectedError): - camera.disconnect() - - # Test deleting the object without connecting first - del camera - - # Test connecting - camera = make_camera(**camera_kwargs) - camera.connect() - assert camera.is_connected - assert camera.fps is not None - assert camera.capture_width is not None - assert camera.capture_height is not None - - # Test connecting twice raises an error - with pytest.raises(DeviceAlreadyConnectedError): - camera.connect() - - # Test reading from the camera - color_image = camera.read() - assert isinstance(color_image, np.ndarray) - assert color_image.ndim == 3 - h, w, c = color_image.shape - assert c == 3 - assert w > h - - # Test read and async_read outputs similar images - # ...warming up as the first frames can be black - for _ in range(30): - camera.read() - color_image = camera.read() - async_color_image = camera.async_read() - error_msg = ( - "max_pixel_difference between read() and async_read()", - compute_max_pixel_difference(color_image, async_color_image), - ) - # TODO(rcadene): properly set `rtol` - np.testing.assert_allclose( - color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg - ) - - # Test disconnecting - camera.disconnect() - assert camera.camera is None - assert camera.thread is None - - # Test disconnecting with `__del__` - camera = make_camera(**camera_kwargs) - camera.connect() - del camera - - # Test acquiring a bgr image - camera = make_camera(**camera_kwargs, color_mode="bgr") - camera.connect() - assert camera.color_mode == "bgr" - bgr_color_image = camera.read() - np.testing.assert_allclose( - color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg - ) - del camera - - # Test acquiring a rotated image - camera = make_camera(**camera_kwargs) - camera.connect() - ori_color_image = camera.read() - del camera - - for rotation in [None, 90, 180, -90]: - camera = make_camera(**camera_kwargs, rotation=rotation) - camera.connect() - - if mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - if rotation is None: - manual_rot_img = ori_color_image - assert camera.rotation is None - elif rotation == 90: - manual_rot_img = np.rot90(color_image, k=1) - assert camera.rotation == cv2.ROTATE_90_CLOCKWISE - elif rotation == 180: - manual_rot_img = np.rot90(color_image, k=2) - assert camera.rotation == cv2.ROTATE_180 - elif rotation == -90: - manual_rot_img = np.rot90(color_image, k=3) - assert camera.rotation == cv2.ROTATE_90_COUNTERCLOCKWISE - - rot_color_image = camera.read() - - np.testing.assert_allclose( - rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg - ) - del camera - - # TODO(rcadene): Add a test for a camera that doesnt support fps=60 and raises an OSError - # TODO(rcadene): Add a test for a camera that supports fps=60 - - # Test width and height can be set - camera = make_camera(**camera_kwargs, fps=30, width=1280, height=720) - camera.connect() - assert camera.fps == 30 - assert camera.width == 1280 - 
assert camera.height == 720 - color_image = camera.read() - h, w, c = color_image.shape - assert h == 720 - assert w == 1280 - assert c == 3 - del camera - - # Test not supported width and height raise an error - camera = make_camera(**camera_kwargs, fps=30, width=0, height=0) - with pytest.raises(OSError): - camera.connect() - del camera - - -@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) -@require_camera -def test_save_images_from_cameras(tmp_path, request, camera_type, mock): - # TODO(rcadene): refactor - if camera_type == "opencv": - from lerobot.common.cameras.opencv.camera_opencv import save_images_from_cameras - elif camera_type == "intelrealsense": - from lerobot.common.cameras.intel.camera_realsense import save_images_from_cameras - - # Small `record_time_s` to speedup unit tests - save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock) - - -@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) -@require_camera -def test_camera_rotation(request, camera_type, mock): - config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30} - - # No rotation. - camera = make_camera(**config_kwargs, rotation=None) - camera.connect() - assert camera.capture_width == 640 - assert camera.capture_height == 480 - assert camera.width == 640 - assert camera.height == 480 - no_rot_img = camera.read() - h, w, c = no_rot_img.shape - assert h == 480 and w == 640 and c == 3 - camera.disconnect() - - # Rotation = 90 (clockwise). - camera = make_camera(**config_kwargs, rotation=90) - camera.connect() - # With a 90° rotation, we expect the metadata dimensions to be swapped. - assert camera.capture_width == 640 - assert camera.capture_height == 480 - assert camera.width == 480 - assert camera.height == 640 - import cv2 - - assert camera.rotation == cv2.ROTATE_90_CLOCKWISE - rot_img = camera.read() - h, w, c = rot_img.shape - assert h == 640 and w == 480 and c == 3 - camera.disconnect() - - # Rotation = 180. - camera = make_camera(**config_kwargs, rotation=None) - camera.connect() - assert camera.capture_width == 640 - assert camera.capture_height == 480 - assert camera.width == 640 - assert camera.height == 480 - no_rot_img = camera.read() - h, w, c = no_rot_img.shape - assert h == 480 and w == 640 and c == 3 - camera.disconnect() diff --git a/tests/cameras/test_opencv.py b/tests/cameras/test_opencv.py new file mode 100644 index 0000000000..6f74a9696b --- /dev/null +++ b/tests/cameras/test_opencv.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Example of running a specific test: +# ```bash +# pytest tests/cameras/test_opencv.py::test_connect +# ``` + +import os + +import numpy as np +import pytest + +from lerobot.common.cameras.configs import Cv2Rotation +from lerobot.common.cameras.opencv import OpenCVCamera, OpenCVCameraConfig +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +# NOTE(Steven): Consider improving the assert coverage +TEST_ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "artifacts", "cameras") +DEFAULT_PNG_FILE_PATH = os.path.join(TEST_ARTIFACTS_DIR, "fakecam_sd_160x120.png") +TEST_IMAGE_PATHS = [ + os.path.join(TEST_ARTIFACTS_DIR, "fakecam_sd_160x120.png"), + os.path.join(TEST_ARTIFACTS_DIR, "fakecam_hd_320x180.png"), + os.path.join(TEST_ARTIFACTS_DIR, "fakecam_fullhd_480x270.png"), + os.path.join(TEST_ARTIFACTS_DIR, "fakecam_square_128x128.png"), +] + + +def test_base_class_implementation(): + config = OpenCVCameraConfig(index_or_path=0) + + _ = OpenCVCamera(config) + + +def test_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + + camera.connect(do_warmup_read=False) + + assert camera.is_connected + + +def test_connect_already_connected(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + camera.connect(do_warmup_read=False) + + with pytest.raises(DeviceAlreadyConnectedError): + camera.connect(do_warmup_read=False) + + +def test_connect_invalid_camera_path(): + config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png") + camera = OpenCVCamera(config) + + with pytest.raises(ConnectionError): + camera.connect(do_warmup_read=False) + + +def test_invalid_width_connect(): + config = OpenCVCameraConfig( + index_or_path=DEFAULT_PNG_FILE_PATH, + width=99999, # Invalid width to trigger error + height=480, + ) + camera = OpenCVCamera(config) + + with pytest.raises(RuntimeError): + camera.connect(do_warmup_read=False) + + +@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS) +def test_read(index_or_path): + config = OpenCVCameraConfig(index_or_path=index_or_path) + camera = OpenCVCamera(config) + camera.connect(do_warmup_read=False) + + img = camera.read() + + assert isinstance(img, np.ndarray) + + +def test_read_before_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.read() + + +def test_disconnect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + camera.connect(do_warmup_read=False) + + camera.disconnect() + + assert not camera.is_connected + + +def test_disconnect_before_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.disconnect() + + +@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS) +def test_async_read(index_or_path): + config = OpenCVCameraConfig(index_or_path=index_or_path) + camera = OpenCVCamera(config) + camera.connect(do_warmup_read=False) + + img = camera.async_read() + + assert camera.thread is not None + assert camera.thread.is_alive() + assert isinstance(img, np.ndarray) + camera.disconnect() # To stop/join the thread. 
Otherwise get warnings when the test ends
+
+
+def test_async_read_timeout():
+    config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
+    camera = OpenCVCamera(config)
+    camera.connect(do_warmup_read=False)
+
+    with pytest.raises(TimeoutError):
+        camera.async_read(timeout_ms=0)
+
+    camera.disconnect()
+
+
+def test_async_read_before_connect():
+    config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
+    camera = OpenCVCamera(config)
+
+    with pytest.raises(DeviceNotConnectedError):
+        _ = camera.async_read()
+
+
+@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS)
+@pytest.mark.parametrize(
+    "rotation",
+    [
+        Cv2Rotation.NO_ROTATION,
+        Cv2Rotation.ROTATE_90,
+        Cv2Rotation.ROTATE_180,
+        Cv2Rotation.ROTATE_270,
+    ],
+)
+def test_all_rotations(rotation, index_or_path):
+    filename = os.path.basename(index_or_path)
+    dimensions = filename.split("_")[-1].split(".")[0]  # Assumes filenames end with _<width>x<height>.png
+    original_width, original_height = map(int, dimensions.split("x"))
+
+    config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation)
+    camera = OpenCVCamera(config)
+    camera.connect(do_warmup_read=False)
+
+    img = camera.read()
+    assert isinstance(img, np.ndarray)
+
+    if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
+        assert camera.width == original_height
+        assert camera.height == original_width
+        assert img.shape[:2] == (original_width, original_height)
+    else:
+        assert camera.width == original_width
+        assert camera.height == original_height
+        assert img.shape[:2] == (original_height, original_width)
+
+    camera.disconnect()
diff --git a/tests/cameras/test_realsense.py b/tests/cameras/test_realsense.py
new file mode 100644
index 0000000000..201704d632
--- /dev/null
+++ b/tests/cameras/test_realsense.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Example of running a specific test:
+# ```bash
+# pytest tests/cameras/test_realsense.py::test_connect
+# ```
+
+import os
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+
+from lerobot.common.cameras.configs import Cv2Rotation
+from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+
+try:
+    import pyrealsense2 as rs  # noqa: F401
+
+    from lerobot.common.cameras.intel import RealSenseCamera, RealSenseCameraConfig
+except (ImportError, ModuleNotFoundError):
+    pytest.skip("pyrealsense2 not available", allow_module_level=True)
+
+TEST_ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "artifacts", "cameras")
+BAG_FILE_PATH = os.path.join(TEST_ARTIFACTS_DIR, "test_rs.bag")
+
+
+if not os.path.exists(BAG_FILE_PATH):
+    print(f"Warning: Bag file not found at {BAG_FILE_PATH}. 
Some tests might fail or be skipped.") + + +def mock_rs_config_enable_device_from_file(rs_config_instance, sn): + if not os.path.exists(BAG_FILE_PATH): + raise FileNotFoundError(f"Test bag file not found: {BAG_FILE_PATH}") + return rs_config_instance.enable_device_from_file(BAG_FILE_PATH, repeat_playback=True) + + +def mock_rs_config_enable_device_bad_file(rs_config_instance, sn): + return rs_config_instance.enable_device_from_file("non_existent_file.bag", repeat_playback=True) + + +def test_base_class_implementation(): + config = RealSenseCameraConfig(serial_number=42) + _ = RealSenseCamera(config) + + +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file) +def test_connect(mock_enable_device): + config = RealSenseCameraConfig(serial_number=42) + camera = RealSenseCamera(config) + + camera.connect(do_warmup_read=False) + assert camera.is_connected + + +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file) +def test_connect_already_connected(mock_enable_device): + config = RealSenseCameraConfig(serial_number=42) + camera = RealSenseCamera(config) + camera.connect(do_warmup_read=False) + + with pytest.raises(DeviceAlreadyConnectedError): + camera.connect(do_warmup_read=False) + + +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_bad_file) +def test_connect_invalid_camera_path(mock_enable_device): + config = RealSenseCameraConfig(serial_number=42) + camera = RealSenseCamera(config) + + with pytest.raises(ConnectionError): + camera.connect(do_warmup_read=False) + + +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file) +def test_invalid_width_connect(mock_enable_device): + config = RealSenseCameraConfig(serial_number=42, width=99999, height=480, fps=30) + camera = RealSenseCamera(config) + + with pytest.raises(ConnectionError): + camera.connect(do_warmup_read=False) + + +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file) +def test_read(mock_enable_device): + config = RealSenseCameraConfig(serial_number=42, width=640, height=480, fps=30) + camera = RealSenseCamera(config) + camera.connect(do_warmup_read=False) + + img = camera.read() + assert isinstance(img, np.ndarray) + + +def test_read_before_connect(): + config = RealSenseCameraConfig(serial_number=42) + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.read() + + +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file) +def test_disconnect(mock_enable_device): + config = RealSenseCameraConfig(serial_number=42) + camera = RealSenseCamera(config) + camera.connect(do_warmup_read=False) + + camera.disconnect() + + assert not camera.is_connected + + +def test_disconnect_before_connect(): + config = RealSenseCameraConfig(serial_number=42) + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + camera.disconnect() + + +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file) +def test_async_read(mock_enable_device): + config = RealSenseCameraConfig(serial_number=42, width=640, height=480, fps=30) + camera = RealSenseCamera(config) + camera.connect(do_warmup_read=False) + + img = camera.async_read() + + assert camera.thread is not None + assert camera.thread.is_alive() + assert isinstance(img, np.ndarray) + camera.disconnect() # To stop/join the thread. 
Otherwise get warnings when the test ends + + +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file) +def test_async_read_timeout(mock_enable_device): + config = RealSenseCameraConfig(serial_number=42, width=640, height=480, fps=30) + camera = RealSenseCamera(config) + camera.connect(do_warmup_read=False) + + with pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) + + camera.disconnect() + + +def test_async_read_before_connect(): + config = RealSenseCameraConfig(serial_number=42) + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.async_read() + + +@pytest.mark.parametrize( + "rotation", + [ + Cv2Rotation.NO_ROTATION, + Cv2Rotation.ROTATE_90, + Cv2Rotation.ROTATE_180, + Cv2Rotation.ROTATE_270, + ], +) +@patch("pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file) +def test_all_rotations(mock_enable_device, rotation): + config = RealSenseCameraConfig(serial_number=42, rotation=rotation) + camera = RealSenseCamera(config) + camera.connect(do_warmup_read=False) + + img = camera.read() + assert isinstance(img, np.ndarray) + + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == 480 + assert camera.height == 640 + assert img.shape[:2] == (640, 480) + else: + assert camera.width == 640 + assert camera.height == 480 + assert img.shape[:2] == (480, 640)
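
The RealSense tests above stay hardware-free by patching `pyrealsense2.config.enable_device` so that `RealSenseCamera.connect()` streams frames from the recorded `tests/artifacts/cameras/test_rs.bag` file. A minimal sketch of how that redirection could be shared across tests as a pytest fixture is shown below; the fixture name `realsense_bag_playback` and the `tests/cameras/conftest.py` location are assumptions, not part of this diff.

```python
# Sketch only (assumed to live in tests/cameras/conftest.py): mirrors the
# @patch(..., side_effect=mock_rs_config_enable_device_from_file) pattern above.
import os
from unittest.mock import patch

import pytest

TEST_ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "artifacts", "cameras")
BAG_FILE_PATH = os.path.join(TEST_ARTIFACTS_DIR, "test_rs.bag")


@pytest.fixture
def realsense_bag_playback():
    """Patch rs.config.enable_device so connect() replays the recorded .bag file."""

    def _enable_from_file(rs_config_instance, _serial_number):
        # Same redirection as in test_realsense.py: ignore the serial number and
        # stream from the bag file instead of opening a physical camera.
        return rs_config_instance.enable_device_from_file(BAG_FILE_PATH, repeat_playback=True)

    with patch("pyrealsense2.config.enable_device", side_effect=_enable_from_file) as mocked:
        yield mocked
```

A test would then request the fixture instead of repeating the decorator, e.g. `def test_read(realsense_bag_playback): ...`, which keeps the bag-file redirection in one place if more RealSense tests are added.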